wgpu_core/indirect_validation/
draw.rs

1use super::{
2    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3    CreateIndirectValidationPipelineError,
4};
5use crate::{
6    command::{get_src_stride_of_indirect_args, RenderPassErrorInner},
7    device::{queue::TempResource, Device, DeviceError},
8    hal_label,
9    lock::{rank, Mutex},
10    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
11    resource::{RawResourceAccess as _, StagingBuffer, Trackable},
12    snatch::SnatchGuard,
13    track::TrackerIndex,
14    FastHashMap,
15};
16use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
17use core::{mem::size_of, num::NonZeroU64};
18use wgt::Limits;
19
20/// Note: This needs to be under:
21///
22/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
23///
24/// = (2^16 - 1) * 2^4 * 2^6
25///
26/// It is currently set to:
27///
28/// = (2^16 - 1) * 2^4
29///
30/// This is enough space for:
31///
32/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
33/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
34const BUFFER_SIZE: wgt::BufferSize = wgt::BufferSize::new(1_048_560).unwrap();
35
/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_immediate_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    /// Shader module compiled from `validate_draw.wgsl`.
    module: Box<dyn hal::DynShaderModule>,
    /// Bind group 0: read-only storage buffer holding metadata entries (no dynamic offset).
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Bind group 1: read-only storage view of the source indirect buffer, with a dynamic offset.
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Bind group 2: read-write storage buffer the validation pass writes to
    /// (presumably the validated indirect args that the render pass then consumes).
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout of bind groups 0..=2 plus 8 bytes of immediates.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline running the shader's `main` entry point.
    pipeline: Box<dyn hal::DynComputePipeline>,

    /// Pool of reusable destination buffers + bind groups, shared by all command buffers.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    /// Pool of reusable metadata buffers + bind groups, shared by all command buffers.
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
59
60impl Draw {
61    pub(super) fn new(
62        device: &dyn hal::DynDevice,
63        required_features: &wgt::Features,
64        instance_flags: wgt::InstanceFlags,
65        backend: wgt::Backend,
66        limits: &Limits,
67    ) -> Result<Self, CreateIndirectValidationPipelineError> {
68        // Indirect draw validation doesn't support buffer sizes higher than u32
69        // since its offsets in the shader and dynamic offsets are u32.
70        //
71        // See also: `u64_offset_to_u32_offset`.
72        assert!(limits.max_buffer_size <= u32::MAX as u64);
73
74        let module = create_validation_module(device, instance_flags)?;
75
76        let metadata_bind_group_layout = create_bind_group_layout(
77            device,
78            true,
79            false,
80            BUFFER_SIZE,
81            hal_label(
82                Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
83                instance_flags,
84            ),
85        )?;
86        let src_bind_group_layout = create_bind_group_layout(
87            device,
88            true,
89            true,
90            wgt::BufferSize::new(4 * 4).unwrap(),
91            hal_label(
92                Some("(wgpu internal) Indirect draw validation source bind group layout"),
93                instance_flags,
94            ),
95        )?;
96        let dst_bind_group_layout = create_bind_group_layout(
97            device,
98            false,
99            false,
100            BUFFER_SIZE,
101            hal_label(
102                Some("(wgpu internal) Indirect draw validation destination bind group layout"),
103                instance_flags,
104            ),
105        )?;
106
107        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
108            label: hal_label(
109                Some("(wgpu internal) Indirect draw validation pipeline layout"),
110                instance_flags,
111            ),
112            flags: hal::PipelineLayoutFlags::empty(),
113            bind_group_layouts: &[
114                Some(metadata_bind_group_layout.as_ref()),
115                Some(src_bind_group_layout.as_ref()),
116                Some(dst_bind_group_layout.as_ref()),
117            ],
118            immediate_size: 8,
119        };
120        let pipeline_layout = unsafe {
121            device
122                .create_pipeline_layout(&pipeline_layout_desc)
123                .map_err(DeviceError::from_hal)?
124        };
125
126        let supports_indirect_first_instance =
127            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
128        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
129        let pipeline = create_validation_pipeline(
130            device,
131            module.as_ref(),
132            pipeline_layout.as_ref(),
133            supports_indirect_first_instance,
134            write_d3d12_special_constants,
135            instance_flags,
136        )?;
137
138        Ok(Self {
139            module,
140            metadata_bind_group_layout,
141            src_bind_group_layout,
142            dst_bind_group_layout,
143            pipeline_layout,
144            pipeline,
145
146            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
147            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
148        })
149    }
150
151    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
152    pub(super) fn create_src_bind_group(
153        &self,
154        device: &dyn hal::DynDevice,
155        limits: &Limits,
156        buffer_size: u64,
157        buffer: &dyn hal::DynBuffer,
158        instance_flags: wgt::InstanceFlags,
159    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
160        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
161        let Some(binding_size) = NonZeroU64::new(binding_size) else {
162            return Ok(None);
163        };
164        let hal_desc = hal::BindGroupDescriptor {
165            label: hal_label(
166                Some("(wgpu internal) Indirect draw validation source bind group"),
167                instance_flags,
168            ),
169            layout: self.src_bind_group_layout.as_ref(),
170            entries: &[hal::BindGroupEntry {
171                binding: 0,
172                resource_index: 0,
173                count: 1,
174            }],
175            // SAFETY: We calculated the binding size to fit within the buffer.
176            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
177            samplers: &[],
178            textures: &[],
179            acceleration_structures: &[],
180            external_textures: &[],
181        };
182        unsafe {
183            device
184                .create_bind_group(&hal_desc)
185                .map(Some)
186                .map_err(DeviceError::from_hal)
187        }
188    }
189
190    fn acquire_dst_entry(
191        &self,
192        device: &dyn hal::DynDevice,
193        instance_flags: wgt::InstanceFlags,
194    ) -> Result<BufferPoolEntry, hal::DeviceError> {
195        let mut free_buffers = self.free_indirect_entries.lock();
196        match free_buffers.pop() {
197            Some(buffer) => Ok(buffer),
198            None => {
199                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
200                create_buffer_and_bind_group(
201                    device,
202                    usage,
203                    self.dst_bind_group_layout.as_ref(),
204                    hal_label(Some("(wgpu internal) Indirect draw validation destination buffer"), instance_flags),
205                    hal_label(Some("(wgpu internal) Indirect draw validation destination bind group layout"), instance_flags),
206                )
207            }
208        }
209    }
210
211    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
212        self.free_indirect_entries.lock().extend(entries);
213    }
214
215    fn acquire_metadata_entry(
216        &self,
217        device: &dyn hal::DynDevice,
218        instance_flags: wgt::InstanceFlags,
219    ) -> Result<BufferPoolEntry, hal::DeviceError> {
220        let mut free_buffers = self.free_metadata_entries.lock();
221        match free_buffers.pop() {
222            Some(buffer) => Ok(buffer),
223            None => {
224                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
225                create_buffer_and_bind_group(
226                    device,
227                    usage,
228                    self.metadata_bind_group_layout.as_ref(),
229                    hal_label(
230                        Some("(wgpu internal) Indirect draw validation metadata buffer"),
231                        instance_flags,
232                    ),
233                    hal_label(
234                        Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
235                        instance_flags,
236                    ),
237                )
238            }
239        }
240    }
241
242    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
243        self.free_metadata_entries.lock().extend(entries);
244    }
245
    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    ///
    /// Encoding order:
    /// 1. Pack every batch's metadata into staging buffers of at most
    ///    `max_staging_buffer_size` bytes, write it, and flush them.
    /// 2. Sub-allocate room for each batch's metadata in the pooled metadata buffers
    ///    (`resources.get_metadata_subrange`).
    /// 3. Barrier, copy staging -> metadata buffers, then push the staging buffers
    ///    into `temp_resources`.
    /// 4. Barrier, then one dispatch per batch with bind groups
    ///    0 = metadata, 1 = source indirect buffer (dynamic offset), 2 = destination.
    /// 5. Barrier the destination buffers back to `INDIRECT`.
    ///
    /// Does nothing if `batcher` holds no batches.
    ///
    /// # Errors
    ///
    /// Returns an error if creating a staging buffer or acquiring a metadata entry fails,
    /// or if a batch's source buffer is no longer valid (`try_raw`).
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        // Upper bound on a single staging buffer; a batch that would overflow the
        // current one starts a new one.
        let max_staging_buffer_size = 1 << 26; // ~67MiB

        let mut staging_buffers = Vec::new();

        // Pass 1: assign each batch a staging buffer index and a byte offset in it.
        // A staging buffer is only created once the run of batches it covers is complete,
        // i.e. when the next batch no longer fits (here) or after the loop (below).
        // NOTE(review): if the very first batch alone exceeded `max_staging_buffer_size`,
        // `current_size` would be 0 here and `NonZeroU64::new(0).unwrap()` would panic;
        // presumably per-batch metadata is bounded far below that (see `BUFFER_SIZE`) — confirm.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            // Index of the buffer that will be created for the current (still open) run.
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        // Create the staging buffer for the last open run.
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Pass 2: write each batch's metadata into its assigned staging buffer slot.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        // All CPU writes are done; flush the staging buffers.
        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Sub-allocate each batch's metadata range in the pooled metadata buffers.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        // Scratch storage reused by every barrier group below.
        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Prepare for the copy: staging MAP_WRITE -> COPY_SRC, metadata
        // STORAGE_READ_ONLY -> COPY_DST. `.unique(..)` presumably deduplicates indices so
        // each buffer receives a single barrier.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy every batch's metadata from staging into its metadata buffer range.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // Hand the staging buffers over to `temp_resources` (presumably to keep them alive
        // until the recorded copies have executed).
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Prepare for the compute pass: metadata COPY_DST -> STORAGE_READ_ONLY,
        // destination INDIRECT -> STORAGE_READ_WRITE.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pass"),
                device.instance_flags,
            ),
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Immediates: [index of the batch's first `MetadataEntry` in the metadata
            // buffer, number of entries in the batch].
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, metadata_bind_group, &[]);
            }

            // Make sure the indirect buffer is still valid.
            // NOTE(review): an error here returns while the compute pass is still open on
            // `encoder` — confirm the caller discards the encoder in that case.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    src_bind_group,
                    &[u64_offset_to_u32_offset(batch.src_dynamic_offset)],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, dst_bind_group, &[]);
            }

            // One workgroup per 64 entries (64 = the shader's workgroup size according to
            // the `BUFFER_SIZE` docs — TODO confirm against validate_draw.wgsl).
            unsafe {
                encoder.dispatch_workgroups([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Give the destination buffers back to the render pass as indirect argument buffers.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }
483
484    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
485        let Draw {
486            module,
487            metadata_bind_group_layout,
488            src_bind_group_layout,
489            dst_bind_group_layout,
490            pipeline_layout,
491            pipeline,
492
493            free_indirect_entries,
494            free_metadata_entries,
495        } = self;
496
497        for entry in free_indirect_entries.into_inner().drain(..) {
498            unsafe {
499                device.destroy_bind_group(entry.bind_group);
500                device.destroy_buffer(entry.buffer);
501            }
502        }
503
504        for entry in free_metadata_entries.into_inner().drain(..) {
505            unsafe {
506                device.destroy_bind_group(entry.bind_group);
507                device.destroy_buffer(entry.buffer);
508            }
509        }
510
511        unsafe {
512            device.destroy_compute_pipeline(pipeline);
513            device.destroy_pipeline_layout(pipeline_layout);
514            device.destroy_bind_group_layout(metadata_bind_group_layout);
515            device.destroy_bind_group_layout(src_bind_group_layout);
516            device.destroy_bind_group_layout(dst_bind_group_layout);
517            device.destroy_shader_module(module);
518        }
519    }
520}
521
522fn create_validation_module(
523    device: &dyn hal::DynDevice,
524    instance_flags: wgt::InstanceFlags,
525) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
526    let src = include_str!("./validate_draw.wgsl");
527
528    #[cfg(feature = "wgsl")]
529    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
530        CreateShaderModuleError::Parsing(naga::error::ShaderError {
531            source: src.to_string(),
532            label: None,
533            inner: Box::new(inner),
534        })
535    })?;
536    #[cfg(not(feature = "wgsl"))]
537    #[allow(clippy::diverging_sub_expression)]
538    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");
539
540    let info = crate::device::create_validator(
541        wgt::Features::IMMEDIATES,
542        wgt::DownlevelFlags::empty(),
543        naga::valid::ValidationFlags::all(),
544    )
545    .validate(&module)
546    .map_err(|inner| {
547        CreateShaderModuleError::Validation(naga::error::ShaderError {
548            source: src.to_string(),
549            label: None,
550            inner: Box::new(inner),
551        })
552    })?;
553    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
554        module: alloc::borrow::Cow::Owned(module),
555        info,
556        debug_source: None,
557    });
558    let hal_desc = hal::ShaderModuleDescriptor {
559        label: hal_label(
560            Some("(wgpu internal) Indirect draw validation shader module"),
561            instance_flags,
562        ),
563        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
564    };
565    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
566        |error| match error {
567            hal::ShaderError::Device(error) => {
568                CreateShaderModuleError::Device(DeviceError::from_hal(error))
569            }
570            hal::ShaderError::Compilation(ref msg) => {
571                log::error!("Shader error: {msg}");
572                CreateShaderModuleError::Generation
573            }
574        },
575    )?;
576
577    Ok(module)
578}
579
580fn create_validation_pipeline(
581    device: &dyn hal::DynDevice,
582    module: &dyn hal::DynShaderModule,
583    pipeline_layout: &dyn hal::DynPipelineLayout,
584    supports_indirect_first_instance: bool,
585    write_d3d12_special_constants: bool,
586    instance_flags: wgt::InstanceFlags,
587) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
588    let pipeline_desc = hal::ComputePipelineDescriptor {
589        label: hal_label(
590            Some("(wgpu internal) Indirect draw validation pipeline"),
591            instance_flags,
592        ),
593        layout: pipeline_layout,
594        stage: hal::ProgrammableStage {
595            module,
596            entry_point: "main",
597            constants: &hashbrown::HashMap::from([
598                (
599                    "supports_indirect_first_instance".to_string(),
600                    f64::from(supports_indirect_first_instance),
601                ),
602                (
603                    "write_d3d12_special_constants".to_string(),
604                    f64::from(write_d3d12_special_constants),
605                ),
606            ]),
607            zero_initialize_workgroup_memory: false,
608        },
609        cache: None,
610    };
611    let pipeline =
612        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
613            hal::PipelineError::Device(error) => {
614                CreateComputePipelineError::Device(DeviceError::from_hal(error))
615            }
616            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
617            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
618                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
619            ),
620            hal::PipelineError::PipelineConstants(_, error) => {
621                CreateComputePipelineError::PipelineConstants(error)
622            }
623        })?;
624
625    Ok(pipeline)
626}
627
628fn create_bind_group_layout(
629    device: &dyn hal::DynDevice,
630    read_only: bool,
631    has_dynamic_offset: bool,
632    min_binding_size: wgt::BufferSize,
633    label: Option<&'static str>,
634) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
635    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
636        label,
637        flags: hal::BindGroupLayoutFlags::empty(),
638        entries: &[wgt::BindGroupLayoutEntry {
639            binding: 0,
640            visibility: wgt::ShaderStages::COMPUTE,
641            ty: wgt::BindingType::Buffer {
642                ty: wgt::BufferBindingType::Storage { read_only },
643                has_dynamic_offset,
644                min_binding_size: Some(min_binding_size),
645            },
646            count: None,
647        }],
648    };
649    let bind_group_layout = unsafe {
650        device
651            .create_bind_group_layout(&bind_group_layout_desc)
652            .map_err(DeviceError::from_hal)?
653    };
654
655    Ok(bind_group_layout)
656}
657
658/// Returns the largest binding size that when combined with dynamic offsets can address the whole buffer.
659fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
660    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size;
661    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
662
663    if buffer_size <= max_storage_buffer_binding_size {
664        buffer_size
665    } else {
666        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
667        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;
668
669        // Can the buffer remainder fit in the binding remainder?
670        // If so, align max binding size and add buffer remainder
671        if buffer_rem <= binding_rem {
672            max_storage_buffer_binding_size - binding_rem + buffer_rem
673        }
674        // If not, align max binding size, shorten it by a chunk and add buffer remainder
675        else {
676            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
677                + buffer_rem
678        }
679    }
680}
681
682/// Splits the given `offset` into a dynamic offset & offset.
683fn calculate_src_offsets(
684    buffer_size: u64,
685    limits: &Limits,
686    offset: u64,
687    data_size: u64,
688) -> (u64, u64) {
689    const MAX_DATA_SIZE: u64 = 20; // indexed indirect draw params are 20B
690    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
691    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
692
693    assert!([16, MAX_DATA_SIZE].contains(&data_size));
694    assert!([32, 64, 128, 256].contains(&min_storage_buffer_offset_alignment));
695    assert!(buffer_size >= data_size);
696    assert!(offset <= buffer_size - data_size);
697    assert!(binding_size <= buffer_size);
698
699    // Invariants that the outputs of this function must satisfy:
700    // - out_dynamic_offset + out_offset = offset
701    // - out_dynamic_offset % min_storage_buffer_offset_alignment = 0
702    // - out_dynamic_offset + binding_size <= buffer_size
703    // - out_offset + data_size <= binding_size
704
705    // Align the max offset in the binding and treat it as the stride between
706    // dynamic offsets.
707    //
708    // `dynamic_offset_stride` could just be `min_storage_buffer_offset_alignment`
709    // but we want to make it as large as possible since setting dynamic
710    // offsets requires extra calls to setBindGroup and then to dispatch,
711    // calls which we want to minimize.
712    //
713    // Use `MAX_DATA_SIZE` instead of the actual `data_size` so that the
714    // resulting stride is the same for both indexed and non-indexed draw calls,
715    // reducing the likelihood of `out_dynamic_offset` being different.
716    let dynamic_offset_stride = binding_size.saturating_sub(MAX_DATA_SIZE)
717        / min_storage_buffer_offset_alignment
718        * min_storage_buffer_offset_alignment;
719    if dynamic_offset_stride == 0 {
720        return (0, offset);
721    }
722
723    let max_dynamic_offset = buffer_size - binding_size;
724    let out_dynamic_offset =
725        max_dynamic_offset.min(offset / dynamic_offset_stride * dynamic_offset_stride);
726    let out_offset = offset - out_dynamic_offset;
727
728    (out_dynamic_offset, out_offset)
729}
730
/// A pooled HAL buffer together with a bind group that binds the whole buffer at
/// binding 0 through one of the validation bind group layouts.
///
/// Created by `create_buffer_and_bind_group`. Entries still in the device-level pools
/// are destroyed explicitly in `Draw::dispose`.
#[derive(Debug)]
struct BufferPoolEntry {
    /// `BUFFER_SIZE`-byte buffer.
    buffer: Box<dyn hal::DynBuffer>,
    /// Bind group covering all of `buffer`.
    bind_group: Box<dyn hal::DynBindGroup>,
}
736
737fn create_buffer_and_bind_group(
738    device: &dyn hal::DynDevice,
739    usage: wgt::BufferUses,
740    bind_group_layout: &dyn hal::DynBindGroupLayout,
741    buffer_label: Option<&'static str>,
742    bind_group_label: Option<&'static str>,
743) -> Result<BufferPoolEntry, hal::DeviceError> {
744    let buffer_desc = hal::BufferDescriptor {
745        label: buffer_label,
746        size: BUFFER_SIZE.get(),
747        usage,
748        memory_flags: hal::MemoryFlags::empty(),
749    };
750    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
751    let bind_group_desc = hal::BindGroupDescriptor {
752        label: bind_group_label,
753        layout: bind_group_layout,
754        entries: &[hal::BindGroupEntry {
755            binding: 0,
756            resource_index: 0,
757            count: 1,
758        }],
759        // SAFETY: We just created the buffer with this size.
760        buffers: &[hal::BufferBinding::new_unchecked(
761            buffer.as_ref(),
762            0,
763            BUFFER_SIZE,
764        )],
765        samplers: &[],
766        textures: &[],
767        acceleration_structures: &[],
768        external_textures: &[],
769    };
770    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
771    Ok(BufferPoolEntry { buffer, bind_group })
772}
773
/// Cursor used while sub-allocating ranges from pooled buffers: which entry is being
/// filled, and the position within it (presumably the next free byte offset — see
/// `get_subrange_impl`).
#[derive(Clone)]
struct CurrentEntry {
    /// Index into `DrawResources::dst_entries` or `DrawResources::metadata_entries`.
    index: usize,
    /// Byte offset within that entry's buffer.
    offset: u64,
}
779
/// Holds all command buffer-level resources that are needed to validate indirect draws.
///
/// Entries are acquired on demand from the device-level pools in [`Draw`] and are
/// returned to those pools when this value is dropped.
pub(crate) struct DrawResources {
    /// Device whose `indirect_validation.draw` pools the entries come from.
    device: Arc<Device>,
    /// Destination buffers, indexed by a batch's `dst_resource_index`.
    dst_entries: Vec<BufferPoolEntry>,
    /// Metadata buffers, indexed by a batch's `metadata_resource_index`.
    metadata_entries: Vec<BufferPoolEntry>,
}
786
787impl Drop for DrawResources {
788    fn drop(&mut self) {
789        if let Some(ref indirect_validation) = self.device.indirect_validation {
790            let indirect_draw_validation = &indirect_validation.draw;
791            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
792            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
793        }
794    }
795}
796
797impl DrawResources {
798    pub(crate) fn new(device: Arc<Device>) -> Self {
799        DrawResources {
800            device,
801            dst_entries: Vec::new(),
802            metadata_entries: Vec::new(),
803        }
804    }
805
806    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
807        self.dst_entries.get(index).unwrap().buffer.as_ref()
808    }
809
810    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
811        self.dst_entries.get(index).unwrap().bind_group.as_ref()
812    }
813
814    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
815        self.metadata_entries.get(index).unwrap().buffer.as_ref()
816    }
817
818    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
819        self.metadata_entries
820            .get(index)
821            .unwrap()
822            .bind_group
823            .as_ref()
824    }
825
826    fn get_dst_subrange(
827        &mut self,
828        size: u64,
829        current_entry: &mut Option<CurrentEntry>,
830    ) -> Result<(usize, u64), DeviceError> {
831        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
832        let ensure_entry = |index: usize| {
833            if self.dst_entries.len() <= index {
834                let entry = indirect_draw_validation
835                    .acquire_dst_entry(self.device.raw(), self.device.instance_flags)?;
836                self.dst_entries.push(entry);
837            }
838            Ok(())
839        };
840        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
841        Ok((entry_data.index, entry_data.offset))
842    }
843
844    fn get_metadata_subrange(
845        &mut self,
846        size: u64,
847        current_entry: &mut Option<CurrentEntry>,
848    ) -> Result<(usize, u64), DeviceError> {
849        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
850        let ensure_entry = |index: usize| {
851            if self.metadata_entries.len() <= index {
852                let entry = indirect_draw_validation
853                    .acquire_metadata_entry(self.device.raw(), self.device.instance_flags)?;
854                self.metadata_entries.push(entry);
855            }
856            Ok(())
857        };
858        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
859        Ok((entry_data.index, entry_data.offset))
860    }
861
862    fn get_subrange_impl(
863        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
864        current_entry: &mut Option<CurrentEntry>,
865        size: u64,
866    ) -> Result<CurrentEntry, DeviceError> {
867        let index = if let Some(current_entry) = current_entry.as_mut() {
868            if current_entry.offset + size <= BUFFER_SIZE.get() {
869                let entry_data = current_entry.clone();
870                current_entry.offset += size;
871                return Ok(entry_data);
872            } else {
873                current_entry.index + 1
874            }
875        } else {
876            0
877        };
878
879        ensure_entry(index).map_err(DeviceError::from_hal)?;
880
881        let entry_data = CurrentEntry { index, offset: 0 };
882
883        *current_entry = Some(CurrentEntry {
884            index,
885            offset: size,
886        });
887
888        Ok(entry_data)
889    }
890}
891
/// Per-draw record consumed by the validation shader.
///
/// This must match the `MetadataEntry` struct used by the shader.
#[repr(C)]
struct MetadataEntry {
    /// Source offset in `u32`s; bit 31 is set for indexed draws.
    src_offset: u32,
    /// Destination offset in `u32`s; bit 30 holds bit 32 of the vertex/index limit
    /// and bit 31 holds bit 32 of the instance limit.
    dst_offset: u32,
    /// Low 32 bits of the clamped vertex or index limit.
    vertex_or_index_limit: u32,
    /// Low 32 bits of the clamped instance limit.
    instance_limit: u32,
}
900
901impl MetadataEntry {
902    fn new(
903        indexed: bool,
904        src_offset: u64,
905        dst_offset: u64,
906        vertex_or_index_limit: u64,
907        instance_limit: u64,
908    ) -> Self {
909        const U32_MAX_AS_U64: u64 = u32::MAX as u64;
910
911        let src_offset = u64_offset_to_u32_offset(src_offset);
912        let src_offset = src_offset / 4; // translate byte offset to offset in u32's
913
914        // `src_offset` needs at most 30 bits,
915        // pack `indexed` in bit 31 of `src_offset`
916        let src_offset = src_offset | ((indexed as u32) << 31);
917
918        // max value for limits since first_X and X_count indirect draw arguments are u32
919        let max_limit = U32_MAX_AS_U64 + U32_MAX_AS_U64; // 1 11111111 11111111 11111111 11111110
920
921        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
922        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
923        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32
924
925        let instance_limit = instance_limit.min(max_limit);
926        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
927        let instance_limit = instance_limit as u32; // truncate the limit to a u32
928
929        let dst_offset = u64_offset_to_u32_offset(dst_offset);
930        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's
931
932        // `dst_offset` needs at most 30 bits,
933        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
934        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
935        let dst_offset =
936            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
937
938        Self {
939            src_offset,
940            dst_offset,
941            vertex_or_index_limit,
942            instance_limit,
943        }
944    }
945}
946
947struct DrawIndirectValidationBatch {
948    src_buffer: Arc<crate::resource::Buffer>,
949    src_dynamic_offset: u64,
950    dst_resource_index: usize,
951    entries: Vec<MetadataEntry>,
952
953    staging_buffer_index: usize,
954    staging_buffer_offset: u64,
955    metadata_resource_index: usize,
956    metadata_buffer_offset: u64,
957}
958
959impl DrawIndirectValidationBatch {
960    /// Data to be written to the metadata buffer.
961    fn metadata(&self) -> &[u8] {
962        unsafe {
963            core::slice::from_raw_parts(
964                self.entries.as_ptr().cast::<u8>(),
965                self.entries.len() * size_of::<MetadataEntry>(),
966            )
967        }
968    }
969}
970
971/// Accumulates all needed data needed to validate indirect draws.
972pub(crate) struct DrawBatcher {
973    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
974    current_dst_entry: Option<CurrentEntry>,
975}
976
977impl DrawBatcher {
978    pub(crate) fn new() -> Self {
979        Self {
980            batches: FastHashMap::default(),
981            current_dst_entry: None,
982        }
983    }
984
985    /// Add an indirect draw to be validated.
986    ///
987    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
988    /// and the offset to be used for the draw.
989    pub(crate) fn add<'a>(
990        &mut self,
991        indirect_draw_validation_resources: &'a mut DrawResources,
992        device: &Device,
993        src_buffer: &Arc<crate::resource::Buffer>,
994        offset: u64,
995        family: crate::command::DrawCommandFamily,
996        vertex_or_index_limit: u64,
997        instance_limit: u64,
998    ) -> Result<(usize, u64), DeviceError> {
999        let stride = crate::command::get_dst_stride_of_indirect_args(device.backend(), family);
1000
1001        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
1002            .get_dst_subrange(stride, &mut self.current_dst_entry)?;
1003
1004        let buffer_size = src_buffer.size;
1005        let limits = device.adapter.limits();
1006        let data_size = get_src_stride_of_indirect_args(family);
1007        let (src_dynamic_offset, src_offset) =
1008            calculate_src_offsets(buffer_size, &limits, offset, data_size);
1009
1010        let src_buffer_tracker_index = src_buffer.tracker_index();
1011
1012        let entry = MetadataEntry::new(
1013            family == crate::command::DrawCommandFamily::DrawIndexed,
1014            src_offset,
1015            dst_offset,
1016            vertex_or_index_limit,
1017            instance_limit,
1018        );
1019
1020        match self.batches.entry((
1021            src_buffer_tracker_index,
1022            src_dynamic_offset,
1023            dst_resource_index,
1024        )) {
1025            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
1026                occupied_entry.get_mut().entries.push(entry)
1027            }
1028            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
1029                vacant_entry.insert(DrawIndirectValidationBatch {
1030                    src_buffer: src_buffer.clone(),
1031                    src_dynamic_offset,
1032                    dst_resource_index,
1033                    entries: vec![entry],
1034
1035                    // these will be initialized once we accumulated all entries for the batch
1036                    staging_buffer_index: 0,
1037                    staging_buffer_offset: 0,
1038                    metadata_resource_index: 0,
1039                    metadata_buffer_offset: 0,
1040                });
1041            }
1042        }
1043
1044        Ok((dst_resource_index, dst_offset))
1045    }
1046}
1047
/// Converts a byte offset to `u32`; indirect draw validation doesn't support u64 offsets.
///
/// # Panics
///
/// Panics if `offset` does not fit in a `u32`. This should never happen due to the
/// assert in [`Draw::new`].
fn u64_offset_to_u32_offset(offset: u64) -> u32 {
    u32::try_from(offset).expect("indirect validation offset exceeds u32::MAX")
}
1054
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `calculate_src_offsets` against a table of hand-computed cases.
    ///
    /// For every case it checks both the generic invariants of the split (asserted in
    /// the loop below) and the exact expected `(dynamic_offset, offset)` pair.
    #[test]
    fn calculate_src_offsets_test() {
        const MBUS: u64 = 256 << 20; // default max_buffer_size
        const MBIS: u64 = 128 << 20; // default max_storage_buffer_binding_size

        // Columns are listed in the first comment of the table; the last two columns
        // are the expected outputs of `calculate_src_offsets`.
        #[rustfmt::skip]
        let cases: &[(u64, u64, u32, u64, u64, u64, u64)] = &[
            // (buffer_size, max_binding_size, offset_alignment, data_size, offset, out_dynamic_offset, out_offset)

            // data at start of buffer
            (MBUS, MBIS, 32,  16, 0, 0, 0),
            // data at end of buffer
            (MBUS, MBIS, 32,  16, MBUS - 16, MBIS, MBIS - 16),
            // data at end of buffer, where buffer_size % alignment != 0
            (MBUS + 4, MBUS, 32, 16, MBUS + 4 - 16, 32, MBUS - 32 + 4 - 16),
            // data before/straddling/after middle of the buffer with
            // max binding size limit being half of the buffer size
            // alignment = 32
            (512, 256, 32,  16, 240, 224, 16), // before middle
            (512, 256, 32,  16, 248, 224, 24), // straddling middle
            (512, 256, 32,  16, 256, 224, 32), // after middle
            // alignment = 64
            (512, 256, 64,  16, 240, 192, 48), // before middle
            (512, 256, 64,  16, 248, 192, 56), // straddling middle
            (512, 256, 64,  16, 256, 192, 64), // after middle
            // alignment = 128
            (512, 256, 128, 16, 240, 128, 112), // before middle
            (512, 256, 128, 16, 248, 128, 120), // straddling middle
            (512, 256, 128, 16, 256, 256, 0), // after middle
            // as above but with data_size = 20
            // alignment = 32
            (512, 256, 32,  20, 236, 224, 12), // before middle
            (512, 256, 32,  20, 244, 224, 20), // straddling middle
            (512, 256, 32,  20, 252, 224, 28), // after middle
            // alignment = 64
            (512, 256, 64,  20, 236, 192, 44), // before middle
            (512, 256, 64,  20, 244, 192, 52), // straddling middle
            (512, 256, 64,  20, 252, 192, 60), // after middle
            // alignment = 128
            (512, 256, 128, 20, 236, 128, 108), // before middle
            (512, 256, 128, 20, 244, 128, 116), // straddling middle
            (512, 256, 128, 20, 252, 128, 124), // after middle
        ];

        for &(
            buffer_size,
            max_storage_buffer_binding_size,
            min_storage_buffer_offset_alignment,
            data_size,
            offset,
            expected_out_dynamic_offset,
            expected_out_offset,
        ) in cases
        {
            // Only the two limits under test differ from the defaults.
            let limits = Limits {
                max_storage_buffer_binding_size,
                min_storage_buffer_offset_alignment,
                ..Limits::default()
            };
            let (out_dynamic_offset, out_offset) =
                calculate_src_offsets(buffer_size, &limits, offset, data_size);
            let binding_size = calculate_src_buffer_binding_size(buffer_size, &limits);
            // check invariants
            // 1. the split reconstructs the original offset
            assert_eq!(out_dynamic_offset + out_offset, offset);
            // 2. the dynamic offset respects the storage buffer offset alignment
            assert_eq!(
                out_dynamic_offset % min_storage_buffer_offset_alignment as u64,
                0
            );
            // 3. the binding starting at the dynamic offset fits inside the buffer
            assert!(out_dynamic_offset + binding_size <= buffer_size);
            // 4. the indirect args fit inside the binding
            assert!(out_offset + data_size <= binding_size);
            // check output matches
            assert_eq!(
                (out_dynamic_offset, out_offset),
                (expected_out_dynamic_offset, expected_out_offset),
                "buffer_size={buffer_size} \
                 max_binding_size={max_storage_buffer_binding_size} \
                 offset_alignment={min_storage_buffer_offset_alignment} \
                 data_size={data_size} \
                 offset={offset}"
            );
        }
    }
}