1use alloc::vec::Vec;
2use arrayvec::ArrayVec;
3use ash::vk;
4use hashbrown::{HashMap, HashSet};
5
/// Lower bound the capacity hint is clamped to when a descriptor pool is
/// created; also the hint used for the first pool of a new bucket.
const POOL_MIN_SETS: u32 = 64;
/// Upper bound the capacity hint is clamped to when a descriptor pool is
/// created.
const POOL_MAX_SETS: u32 = 512;
8
/// Number of descriptors of each type consumed by one descriptor set.
///
/// The counts are part of the `BucketKey`, so layouts with identical counts
/// share the same pools. When a pool is created each count is multiplied by
/// the pool capacity to size the matching Vulkan descriptor type.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct DescriptorCounts {
    /// Descriptors of type `SAMPLER`.
    pub sampler: u32,
    /// Descriptors of type `SAMPLED_IMAGE`.
    pub sampled_image: u32,
    /// Descriptors of type `STORAGE_IMAGE`.
    pub storage_image: u32,
    /// Descriptors of type `UNIFORM_BUFFER`.
    pub uniform_buffer: u32,
    /// Descriptors of type `UNIFORM_BUFFER_DYNAMIC`.
    pub uniform_buffer_dynamic: u32,
    /// Descriptors of type `STORAGE_BUFFER`.
    pub storage_buffer: u32,
    /// Descriptors of type `STORAGE_BUFFER_DYNAMIC`.
    pub storage_buffer_dynamic: u32,
    /// Descriptors of type `ACCELERATION_STRUCTURE_KHR`.
    pub acceleration_structure: u32,
}
22
23impl DescriptorCounts {
24 fn total(&self) -> u32 {
25 self.sampler
26 + self.sampled_image
27 + self.storage_image
28 + self.uniform_buffer
29 + self.uniform_buffer_dynamic
30 + self.storage_buffer
31 + self.storage_buffer_dynamic
32 + self.acceleration_structure
33 }
34}
35
36#[derive(Debug)]
37pub struct DescriptorSet {
38 raw: vk::DescriptorSet,
39 bucket_key: BucketKey,
40 pool_index: usize,
41}
42
43impl DescriptorSet {
44 pub fn raw(&self) -> vk::DescriptorSet {
45 self.raw
46 }
47}
48
49#[derive(Debug, PartialEq, Eq, Hash)]
50struct BucketKey {
51 counts: DescriptorCounts,
52 update_after_bind: bool,
53}
54
55struct Pool {
56 raw: vk::DescriptorPool,
57 capacity: u32,
58 available: u32,
59}
60
61#[derive(Default)]
63struct Bucket {
64 layouts: HashSet<vk::DescriptorSetLayout>,
67 pools: Vec<Pool>,
68 available_sets: u32,
69 allocated_sets: u32,
70}
71
72impl Bucket {
73 fn create_pool(
74 &mut self,
75 device: &ash::Device,
76 key: &BucketKey,
77 capacity_hint: u32,
78 ) -> Result<(usize, &mut Pool), crate::DeviceError> {
79 let index = self.pools.len();
80 let pool = create_descriptor_pool(device, key, capacity_hint)?;
81 self.available_sets += pool.capacity;
82 self.pools.push(pool);
83 Ok((index, self.pools.last_mut().unwrap()))
84 }
85}
86
87pub struct DescriptorAllocator {
88 buckets: HashMap<BucketKey, Bucket>,
89 max_update_after_bind_descriptors_in_all_pools: u32,
90 update_after_bind_descriptors_in_all_pools: u32,
91}
92
93impl super::BindGroupLayout {
94 fn bucket_key(&self) -> BucketKey {
95 let update_after_bind = self.contains_binding_arrays;
96 let counts = self.desc_count.clone();
97 BucketKey {
98 counts,
99 update_after_bind,
100 }
101 }
102}
103
104impl DescriptorAllocator {
105 pub fn new(max_update_after_bind_descriptors_in_all_pools: u32) -> Self {
106 DescriptorAllocator {
107 buckets: HashMap::default(),
108 max_update_after_bind_descriptors_in_all_pools,
109 update_after_bind_descriptors_in_all_pools: 0,
110 }
111 }
112
113 pub fn register_layout(
114 &mut self,
115 device: &ash::Device,
116 layout: &super::BindGroupLayout,
117 ) -> Result<(), crate::DeviceError> {
118 let key = layout.bucket_key();
119 let bucket = match self.buckets.entry(key) {
120 hashbrown::hash_map::Entry::Occupied(occupied_entry) => occupied_entry.into_mut(),
121 hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
122 let mut bucket = Bucket::default();
123 bucket.create_pool(device, vacant_entry.key(), POOL_MIN_SETS)?;
126 vacant_entry.insert(bucket)
127 }
128 };
129
130 assert!(bucket.layouts.insert(layout.raw));
131
132 Ok(())
133 }
134
135 pub fn unregister_layout(&mut self, device: &ash::Device, layout: &super::BindGroupLayout) {
136 let key = layout.bucket_key();
137 let bucket = self.buckets.get_mut(&key).unwrap();
138
139 assert!(bucket.layouts.remove(&layout.raw));
140
141 if bucket.layouts.is_empty() {
142 let bucket = self.buckets.remove(&key).unwrap();
144 for pool in bucket.pools {
145 assert_eq!(
146 pool.available, pool.capacity,
147 "pool is not empty, at least one DescriptorSet has not been freed"
148 );
149 unsafe { device.destroy_descriptor_pool(pool.raw, None) };
150 }
151 }
152 }
153
154 pub unsafe fn alloc(
155 &mut self,
156 device: &ash::Device,
157 layout: &super::BindGroupLayout,
158 ) -> Result<DescriptorSet, crate::DeviceError> {
159 let update_after_bind = layout.contains_binding_arrays;
160 let total_descriptors = layout.desc_count.total();
161
162 if update_after_bind
163 && self.max_update_after_bind_descriptors_in_all_pools
164 - self.update_after_bind_descriptors_in_all_pools
165 < total_descriptors
166 {
167 return Err(crate::DeviceError::OutOfMemory);
168 }
169
170 let key = layout.bucket_key();
171 let bucket = self.buckets.get_mut(&key).unwrap();
172
173 let pool = bucket
177 .pools
178 .iter_mut()
179 .enumerate()
180 .find(|(_, pool)| pool.available != 0);
181
182 let (pool_index, pool) = if let Some(pool) = pool {
183 pool
184 } else {
185 let capacity_hint = bucket.allocated_sets;
186 bucket.create_pool(device, &key, capacity_hint)?
187 };
188
189 let vk_info = vk::DescriptorSetAllocateInfo::default()
190 .descriptor_pool(pool.raw)
191 .set_layouts(core::slice::from_ref(&layout.raw));
192
193 let raw = match unsafe { device.allocate_descriptor_sets(&vk_info) } {
194 Ok(sets) => Ok(sets[0]),
195 Err(vk::Result::ERROR_OUT_OF_POOL_MEMORY) => unreachable!(),
197 Err(vk::Result::ERROR_FRAGMENTED_POOL) => unreachable!(),
209 Err(err) => Err(super::map_host_device_oom_err(err)),
210 }?;
211
212 pool.available -= 1;
213 bucket.available_sets -= 1;
214 bucket.allocated_sets += 1;
215 if update_after_bind {
216 self.update_after_bind_descriptors_in_all_pools += total_descriptors;
217 }
218
219 Ok(DescriptorSet {
220 raw,
221 bucket_key: key,
222 pool_index,
223 })
224 }
225
226 pub unsafe fn free(&mut self, device: &ash::Device, set: DescriptorSet) {
227 let bucket = self.buckets.get_mut(&set.bucket_key).unwrap();
228 let pool = bucket.pools.get_mut(set.pool_index).unwrap();
229
230 let result =
231 unsafe { device.free_descriptor_sets(pool.raw, core::slice::from_ref(&set.raw())) };
232 if let Err(err) = result {
233 panic!("vkFreeDescriptorSets error: {err}, please report this error");
240 }
241
242 pool.available += 1;
243 bucket.available_sets += 1;
244 bucket.allocated_sets -= 1;
245 if set.bucket_key.update_after_bind {
246 self.update_after_bind_descriptors_in_all_pools -= set.bucket_key.counts.total();
247 }
248
249 let pool = bucket.pools.last().unwrap();
255 if pool.available == pool.capacity
256 && bucket.available_sets - pool.capacity > pool.capacity / 4
257 {
258 let pool = bucket.pools.pop().unwrap();
259 unsafe { device.destroy_descriptor_pool(pool.raw, None) };
260 bucket.available_sets -= pool.capacity;
261 }
262 }
263}
264
265impl Drop for DescriptorAllocator {
266 fn drop(&mut self) {
267 if !std::thread::panicking() {
268 assert!(
269 self.buckets.is_empty(),
270 "buckets are not empty, at least one BGL has not been unregistered"
271 )
272 }
273 }
274}
275
276fn create_descriptor_pool(
277 device: &ash::Device,
278 key: &BucketKey,
279 capacity_hint: u32,
280) -> Result<Pool, crate::DeviceError> {
281 let counts = &key.counts;
282
283 const NR_OF_DESCRIPTOR_TYPES: usize = 8;
284
285 use vk::DescriptorType as Dt;
286 let counts: [_; NR_OF_DESCRIPTOR_TYPES] = [
287 (Dt::SAMPLER, counts.sampler),
288 (Dt::SAMPLED_IMAGE, counts.sampled_image),
289 (Dt::STORAGE_IMAGE, counts.storage_image),
290 (Dt::UNIFORM_BUFFER, counts.uniform_buffer),
291 (Dt::UNIFORM_BUFFER_DYNAMIC, counts.uniform_buffer_dynamic),
292 (Dt::STORAGE_BUFFER, counts.storage_buffer),
293 (Dt::STORAGE_BUFFER_DYNAMIC, counts.storage_buffer_dynamic),
294 (
295 Dt::ACCELERATION_STRUCTURE_KHR,
296 counts.acceleration_structure,
297 ),
298 ];
299
300 let mut capacity = capacity_hint
302 .clamp(POOL_MIN_SETS, POOL_MAX_SETS)
303 .next_power_of_two();
304
305 for (_, count) in counts {
309 capacity = (u32::MAX / count.max(1)).min(capacity);
310 }
311
312 let pool_sizes = counts
313 .into_iter()
314 .filter(|&(_, count)| count != 0)
315 .map(|(ty, count)| vk::DescriptorPoolSize {
316 ty,
317 descriptor_count: count * capacity,
318 })
319 .collect::<ArrayVec<_, NR_OF_DESCRIPTOR_TYPES>>();
320
321 let mut flags = vk::DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET;
322 if key.update_after_bind {
323 flags |= vk::DescriptorPoolCreateFlags::UPDATE_AFTER_BIND;
324 };
325
326 let vk_info = vk::DescriptorPoolCreateInfo::default()
327 .flags(flags)
328 .max_sets(capacity)
329 .pool_sizes(&pool_sizes);
330
331 let raw = unsafe { device.create_descriptor_pool(&vk_info, None) }
332 .map_err(super::map_host_device_oom_and_fragmentation_err)?;
333
334 Ok(Pool {
335 raw,
336 capacity,
337 available: capacity,
338 })
339}