wgpu_core/indirect_validation/draw.rs

use super::{
    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
    CreateIndirectValidationPipelineError,
};
use crate::{
    device::{queue::TempResource, Device, DeviceError},
    lock::{rank, Mutex},
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
    resource::{StagingBuffer, Trackable},
    snatch::SnatchGuard,
    track::TrackerIndex,
    FastHashMap,
};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{
    mem::{size_of, size_of_val},
    num::NonZeroU64,
};
use wgt::Limits;

/// Note: This needs to be under:
///
/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
///
/// = (2^16 - 1) * 2^4 * 2^6
///
/// It is currently set to:
///
/// = (2^16 - 1) * 2^4
///
/// This is enough space for:
///
/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };
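
// Sanity-check sketch: compile-time confirmation of the arithmetic in the doc
// comment above (65535 16-byte draw args, 52428 20-byte indexed draw args).
const _: () = {
    assert!(BUFFER_SIZE.get() == 65535 * 16);
    assert!(BUFFER_SIZE.get() / 20 == 52428);
};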

/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_push_constant_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}

impl Draw {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // SAFETY: We calculated the binding size to fit within the buffer.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
            }
        }
    }

    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                )
            }
        }
    }

    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), DeviceError> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // 64 MiB

        let mut staging_buffers = Vec::new();

        // Assign each batch a subrange of a staging buffer, creating a new staging
        // buffer whenever the current one would exceed the size limit.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Transition the staging buffers to be copy sources and the metadata
        // buffers to be copy destinations for the upcoming copies.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_push_constants(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            // Each workgroup validates 64 entries; this must match the
            // `workgroup_size` used by the shader.
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    let info = crate::device::create_validator(
        wgt::Features::PUSH_CONSTANTS,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}

fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}

fn create_bind_group_layout(
    device: &dyn hal::DynDevice,
    read_only: bool,
    has_dynamic_offset: bool,
    min_binding_size: wgt::BufferSize,
) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
        label: None,
        flags: hal::BindGroupLayoutFlags::empty(),
        entries: &[wgt::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgt::ShaderStages::COMPUTE,
            ty: wgt::BindingType::Buffer {
                ty: wgt::BufferBindingType::Storage { read_only },
                has_dynamic_offset,
                min_binding_size: Some(min_binding_size),
            },
            count: None,
        }],
    };
    let bind_group_layout = unsafe {
        device
            .create_bind_group_layout(&bind_group_layout_desc)
            .map_err(DeviceError::from_hal)?
    };

    Ok(bind_group_layout)
}

/// Returns the largest binding size that, when combined with dynamic offsets, can address the
/// whole buffer.
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        // Can the buffer remainder fit in the binding remainder?
        // If so, align the max binding size and add the buffer remainder.
        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        }
        // If not, align the max binding size, shorten it by one alignment chunk and add the
        // buffer remainder.
        else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}
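
// A sanity-check sketch for the binding-size math above; the limit values here are
// hypothetical, chosen so the buffer is larger than the maximum binding size and
// misaligned with respect to the offset alignment.
#[cfg(test)]
#[test]
fn src_buffer_binding_size_example() {
    let limits = Limits {
        max_storage_buffer_binding_size: 256,
        min_storage_buffer_offset_alignment: 16,
        ..Limits::default()
    };

    // Buffers that fit in a single binding are bound in full.
    assert_eq!(calculate_src_buffer_binding_size(100, &limits), 100);

    // buffer_rem = 1000 % 16 = 8 can't fit in binding_rem = 256 % 16 = 0, so the
    // binding is shortened by one alignment chunk: 256 - 0 - 16 + 8 = 248.
    assert_eq!(calculate_src_buffer_binding_size(1000, &limits), 248);
}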

/// Splits the given absolute `offset` into a dynamic offset and an offset relative to the
/// bound range.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        // No need to adjust since the src_offset is 4 byte aligned.
        4 => 0,
        // With 16/20 bytes of data we can straddle up to two 8-byte boundaries:
        //  - 16 bytes of data: (4|8|4)
        //  - 20 bytes of data: (4|8|8, 8|8|4)
        8 => 2,
        // With 16/20 bytes of data we can straddle at most one boundary of 16 bytes or more:
        //  - 16 bytes of data: (4|12, 8|8, 12|4)
        //  - 20 bytes of data: (4|16, 8|12, 12|8, 16|4)
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
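
// A companion sketch showing how an absolute `offset` splits into a dynamic offset
// plus a binding-relative offset, using the same hypothetical limits as above.
#[cfg(test)]
#[test]
fn src_offsets_example() {
    let limits = Limits {
        max_storage_buffer_binding_size: 256,
        min_storage_buffer_offset_alignment: 16,
        ..Limits::default()
    };

    // The whole buffer fits in one binding, so the dynamic offset is 0.
    assert_eq!(calculate_src_offsets(100, &limits, 60), (0, 60));

    // binding_size = 248 (see above), chunks = 248 / 16 = 15, and with a 16-byte
    // alignment the stride is (15 - 1) * 16 = 224. Offset 500 falls in the third
    // window: dynamic offset 2 * 224 = 448, leaving 500 - 448 = 52.
    assert_eq!(calculate_src_offsets(1000, &limits, 500), (448, 52));
}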

#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}

fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        // SAFETY: We just created the buffer with this size.
        buffers: &[hal::BufferBinding::new_unchecked(
            buffer.as_ref(),
            0,
            BUFFER_SIZE,
        )],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
        external_textures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}

#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}

/// Holds all command buffer-level resources that are needed to validate indirect draws.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}

impl Drop for DrawResources {
    fn drop(&mut self) {
        if let Some(ref indirect_validation) = self.device.indirect_validation {
            let indirect_draw_validation = &indirect_validation.draw;
            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
        }
    }
}

impl DrawResources {
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}
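
// A sketch of the bump-allocation behavior of `get_subrange_impl`: subranges are
// carved out of the current pool entry until `BUFFER_SIZE` would be exceeded, at
// which point allocation moves on to the next entry.
#[cfg(test)]
#[test]
fn get_subrange_impl_example() {
    let mut current = None;

    // The first request creates entry 0 and starts at offset 0.
    let a = DrawResources::get_subrange_impl(|_| Ok(()), &mut current, 16).unwrap();
    assert_eq!((a.index, a.offset), (0, 0));

    // The next request continues where the previous one ended.
    let b = DrawResources::get_subrange_impl(|_| Ok(()), &mut current, 32).unwrap();
    assert_eq!((b.index, b.offset), (0, 16));

    // A request that would overflow BUFFER_SIZE moves to entry 1, offset 0.
    let c = DrawResources::get_subrange_impl(|_| Ok(()), &mut current, BUFFER_SIZE.get()).unwrap();
    assert_eq!((c.index, c.offset), (1, 0));
}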

/// This must match the `MetadataEntry` struct used by the shader.
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}

impl MetadataEntry {
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        debug_assert_eq!(
            4,
            size_of_val(&Limits::default().max_storage_buffer_binding_size)
        );

        let src_offset = src_offset as u32; // max_storage_buffer_binding_size is a u32
        let src_offset = src_offset / 4; // translate byte offset to offset in u32's

        // `src_offset` needs at most 30 bits,
        // pack `indexed` in bit 31 of `src_offset`
        let src_offset = src_offset | ((indexed as u32) << 31);

        // max value for limits since first_X and X_count indirect draw arguments are u32
        let max_limit = u32::MAX as u64 + u32::MAX as u64; // 1 11111111 11111111 11111111 11111110

        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32

        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
        let instance_limit = instance_limit as u32; // truncate the limit to a u32

        let dst_offset = dst_offset as u32; // max_storage_buffer_binding_size is a u32
        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's

        // `dst_offset` needs at most 30 bits,
        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}
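
// A bit-packing sketch for `MetadataEntry::new` with hypothetical arguments: byte
// offsets become u32 offsets, `indexed` lands in bit 31 of `src_offset`, and the
// 33rd bit of each clamped limit lands in bits 30/31 of `dst_offset`.
#[cfg(test)]
#[test]
fn metadata_entry_packing_example() {
    let entry = MetadataEntry::new(true, 8, 12, u32::MAX as u64 + 5, 3);

    // 8 bytes -> 2 u32's, with the `indexed` flag packed into bit 31.
    assert_eq!(entry.src_offset, 2u32 | (1 << 31));
    // The vertex/index limit exceeds u32::MAX, so its 33rd bit is packed into
    // bit 30 of `dst_offset` (12 bytes -> 3 u32's).
    assert_eq!(entry.dst_offset, 3u32 | (1 << 30));
    assert_eq!(entry.vertex_or_index_limit, 4); // low 32 bits of u32::MAX + 5
    assert_eq!(entry.instance_limit, 3);
}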

struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}

impl DrawIndirectValidationBatch {
    /// Data to be written to the metadata buffer.
    fn metadata(&self) -> &[u8] {
        // SAFETY: `MetadataEntry` is `#[repr(C)]` and consists only of `u32`s, so the
        // entries are valid to view as plain bytes for the full length of the slice.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}

/// Accumulates all data needed to validate indirect draws.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    current_dst_entry: Option<CurrentEntry>,
}

impl DrawBatcher {
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

    /// Add an indirect draw to be validated.
    ///
    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
    /// and the offset to be used for the draw.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // space for D3D12 special constants
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        // Batches are keyed by source buffer, dynamic offset, and destination buffer,
        // so every entry in a batch can be validated by a single dispatch.
        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // these will be initialized once we have accumulated all entries
                    // for the batch
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}