wgpu_core/indirect_validation/draw.rs

use super::{
    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
    CreateIndirectValidationPipelineError,
};
use crate::{
    command::RenderPassErrorInner,
    device::{queue::TempResource, Device, DeviceError},
    lock::{rank, Mutex},
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
    resource::{RawResourceAccess as _, StagingBuffer, Trackable},
    snatch::SnatchGuard,
    track::TrackerIndex,
    FastHashMap,
};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{
    mem::{size_of, size_of_val},
    num::NonZeroU64,
};
use wgt::Limits;

/// Note: This needs to be under:
///
/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
///
/// = (2^16 - 1) * 2^4 * 2^6
///
/// It is currently set to:
///
/// = (2^16 - 1) * 2^4
///
/// This is enough space for:
///
/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };
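
// A minimal compile-time sanity check of the sizing argument above. This is an
// illustrative sketch, not from the original file: it assumes the default
// dispatch limit of 2^16 - 1 workgroups, a 16-byte `wgt::DrawIndirectArgs`,
// and the shader's workgroup size of 64.
const _: () = {
    // workgroups * arg size * workgroup size
    let upper_bound = (u16::MAX as u64) * 16 * 64;
    assert!(BUFFER_SIZE.get() <= upper_bound);
    // Each `MetadataEntry` / `wgt::DrawIndirectArgs` is 16 bytes, so the pool
    // divides evenly into 65535 of them.
    assert!(BUFFER_SIZE.get() % 16 == 0 && BUFFER_SIZE.get() / 16 == 65535);
};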

/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_immediate_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}

impl Draw {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            immediates_ranges: &[wgt::ImmediateRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
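        // On DX12, each validated draw's arguments are preceded by three extra
        // u32 "special constants" (see the `extra` stride computation in
        // `DrawBatcher::add`), so the shader is asked to write those as well.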
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // SAFETY: We calculated the binding size to fit within the buffer.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
            }
        }
    }

    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                )
            }
        }
    }

    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // 64 MiB

        let mut staging_buffers = Vec::new();

        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            // Make sure the indirect buffer is still valid.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

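            // Each workgroup validates 64 metadata entries, matching the
            // `workgroup_size` used by the validation shader (see the
            // `BUFFER_SIZE` note at the top of this file).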
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}

fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}

fn create_bind_group_layout(
    device: &dyn hal::DynDevice,
    read_only: bool,
    has_dynamic_offset: bool,
    min_binding_size: wgt::BufferSize,
) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
        label: None,
        flags: hal::BindGroupLayoutFlags::empty(),
        entries: &[wgt::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgt::ShaderStages::COMPUTE,
            ty: wgt::BindingType::Buffer {
                ty: wgt::BufferBindingType::Storage { read_only },
                has_dynamic_offset,
                min_binding_size: Some(min_binding_size),
            },
            count: None,
        }],
    };
    let bind_group_layout = unsafe {
        device
            .create_bind_group_layout(&bind_group_layout_desc)
            .map_err(DeviceError::from_hal)?
    };

    Ok(bind_group_layout)
}

/// Returns the largest binding size that, when combined with dynamic offsets, can address the whole buffer.
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        // Can the buffer remainder fit in the binding remainder?
        // If so, align max binding size and add buffer remainder
        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        }
        // If not, align max binding size, shorten it by a chunk and add buffer remainder
        else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}
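
// A hedged worked example of the sizing rule above, with made-up limits small
// enough to read (a real `max_storage_buffer_binding_size` is much larger).
// With a 272-byte buffer, a 256-byte max binding size, and 32-byte offset
// alignment, the buffer remainder (16) does not fit in the binding remainder
// (0), so the binding is shortened by one alignment chunk: 256 - 32 + 16 = 240.
// A 32-byte-aligned dynamic offset of 32 then reaches the buffer's tail:
// 32 + 240 = 272.
#[cfg(test)]
mod src_binding_size_sketch {
    use super::*;

    #[test]
    fn shortens_binding_to_keep_buffer_tail_addressable() {
        let limits = Limits {
            max_storage_buffer_binding_size: 256,
            min_storage_buffer_offset_alignment: 32,
            ..Limits::default()
        };
        assert_eq!(calculate_src_buffer_binding_size(272, &limits), 240);
        // A buffer that already fits within the max binding size is bound whole.
        assert_eq!(calculate_src_buffer_binding_size(100, &limits), 100);
    }
}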

/// Splits the given `offset` into a dynamic offset and a remaining offset.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        // No need to adjust since the src_offset is 4-byte aligned.
        4 => 0,
        // With 16/20 bytes of data we can straddle up to two 8-byte boundaries:
        //  - 16 bytes of data: (4|8|4)
        //  - 20 bytes of data: (4|8|8, 8|8|4)
        8 => 2,
        // With 16/20 bytes of data we can straddle up to one boundary of 16+ bytes:
        //  - 16 bytes of data: (4|12, 8|8, 12|4)
        //  - 20 bytes of data: (4|16, 8|12, 12|8, 16|4)
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
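
// An illustrative continuation of the sketch above (assumed limits, not real
// defaults): with a 100-byte buffer, a 64-byte max binding size, and 16-byte
// alignment, the binding size is 64 - 16 + 4 = 52 and the dynamic-offset
// stride is (52 / 16 - 1) * 16 = 32. A draw at byte offset 40 is then split
// into a dynamic offset of 32 plus a local offset of 8, and 8 + 20 bytes of
// indexed draw arguments still fit inside the 52-byte binding.
#[cfg(test)]
mod src_offsets_sketch {
    use super::*;

    #[test]
    fn splits_offset_into_dynamic_and_local_parts() {
        let limits = Limits {
            max_storage_buffer_binding_size: 64,
            min_storage_buffer_offset_alignment: 16,
            ..Limits::default()
        };
        assert_eq!(calculate_src_offsets(100, &limits, 40), (32, 8));
    }
}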

#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}

fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        // SAFETY: We just created the buffer with this size.
        buffers: &[hal::BufferBinding::new_unchecked(
            buffer.as_ref(),
            0,
            BUFFER_SIZE,
        )],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
        external_textures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}

#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}

/// Holds all command buffer-level resources that are needed to validate indirect draws.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}

impl Drop for DrawResources {
    fn drop(&mut self) {
        if let Some(ref indirect_validation) = self.device.indirect_validation {
            let indirect_draw_validation = &indirect_validation.draw;
            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
        }
    }
}

impl DrawResources {
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}
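
// A small sketch of the bump-allocation behavior above (illustrative only):
// requests are packed into the current pool entry until one no longer fits,
// at which point the allocator moves on to the next entry.
#[cfg(test)]
mod subrange_sketch {
    use super::*;

    #[test]
    fn bumps_within_an_entry_then_rolls_over() {
        // A no-op stand-in for the pool-growing closure.
        let ensure = |_: usize| -> Result<(), hal::DeviceError> { Ok(()) };
        let mut current = None;

        let a = DrawResources::get_subrange_impl(ensure, &mut current, 16).unwrap();
        assert_eq!((a.index, a.offset), (0, 0));

        // A request for the full pool size cannot share entry 0 with `a`.
        let b = DrawResources::get_subrange_impl(ensure, &mut current, BUFFER_SIZE.get()).unwrap();
        assert_eq!((b.index, b.offset), (1, 0));
    }
}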

/// This must match the `MetadataEntry` struct used by the shader.
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}
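
// Sanity check (illustrative, not from the original file): the byte view in
// `DrawIndirectValidationBatch::metadata` and the offset arithmetic in
// `inject_validation_pass` both rely on `MetadataEntry` being four tightly
// packed u32s, i.e. 16 bytes with no padding.
const _: () = assert!(size_of::<MetadataEntry>() == 4 * size_of::<u32>());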

impl MetadataEntry {
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        debug_assert_eq!(
            4,
            size_of_val(&Limits::default().max_storage_buffer_binding_size)
        );

        let src_offset = src_offset as u32; // max_storage_buffer_binding_size is a u32
        let src_offset = src_offset / 4; // translate byte offset to offset in u32's

        // `src_offset` needs at most 30 bits,
        // pack `indexed` in bit 31 of `src_offset`
        let src_offset = src_offset | ((indexed as u32) << 31);

        // max value for limits since first_X and X_count indirect draw arguments are u32
        let max_limit = u32::MAX as u64 + u32::MAX as u64; // 1 11111111 11111111 11111111 11111110

        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32

        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
        let instance_limit = instance_limit as u32; // truncate the limit to a u32

        let dst_offset = dst_offset as u32; // max_storage_buffer_binding_size is a u32
        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's

        // `dst_offset` needs at most 30 bits,
        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}
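
// A hedged example of the packing scheme above: the `indexed` flag lands in
// bit 31 of `src_offset`, and bit 32 of each clamped limit lands in bits
// 30/31 of `dst_offset`, while the low 32 bits of each limit are stored
// directly.
#[cfg(test)]
mod metadata_entry_sketch {
    use super::*;

    #[test]
    fn packs_flag_and_limit_high_bits() {
        // A limit of 2^32 keeps bit 32 set after clamping to 2 * u32::MAX.
        let entry = MetadataEntry::new(true, 16, 8, 1 << 32, 2);
        assert_eq!(entry.src_offset, (16 / 4) | (1 << 31));
        assert_eq!(entry.dst_offset, (8 / 4) | (1 << 30));
        assert_eq!(entry.vertex_or_index_limit, 0); // low 32 bits of 2^32
        assert_eq!(entry.instance_limit, 2);
    }
}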

struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}

impl DrawIndirectValidationBatch {
    /// Data to be written to the metadata buffer.
    fn metadata(&self) -> &[u8] {
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}

/// Accumulates all data needed to validate indirect draws.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    current_dst_entry: Option<CurrentEntry>,
}

impl DrawBatcher {
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

    /// Add an indirect draw to be validated.
    ///
    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
    /// and the offset to be used for the draw.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // space for D3D12 special constants
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],
                    // these will be initialized once we have accumulated all entries for the batch
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}
981}