wgpu_core/indirect_validation/
draw.rs

1use super::{
2    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3    CreateIndirectValidationPipelineError,
4};
5use crate::{
6    command::RenderPassErrorInner,
7    device::{queue::TempResource, Device, DeviceError},
8    hal_label,
9    lock::{rank, Mutex},
10    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
11    resource::{RawResourceAccess as _, StagingBuffer, Trackable},
12    snatch::SnatchGuard,
13    track::TrackerIndex,
14    FastHashMap,
15};
16use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
17use core::{mem::size_of, num::NonZeroU64};
18use wgt::Limits;
19
/// Size of each pooled metadata/destination buffer.
///
/// Note: This needs to be under:
///
/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
///
/// = (2^16 - 1) * 2^4 * 2^6
///
/// It is currently set to:
///
/// = (2^16 - 1) * 2^4
///
/// This is enough space for:
///
/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
const BUFFER_SIZE: wgt::BufferSize = wgt::BufferSize::new(1_048_560).unwrap();
35
/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_immediate_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    /// Compiled validation compute shader (see `create_validation_module`).
    module: Box<dyn hal::DynShaderModule>,
    /// Layout for bind group 0: read-only draw metadata.
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 1: the user's source indirect buffer
    /// (read-only, bound with a dynamic offset).
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 2: the destination buffer validated args are written to.
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    /// Pool of reusable destination (INDIRECT | STORAGE_READ_WRITE) buffers
    /// with their bind groups.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    /// Pool of reusable metadata (COPY_DST | STORAGE_READ_ONLY) buffers
    /// with their bind groups.
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
59
impl Draw {
    /// Creates the device-level validation machinery: shader module, the three
    /// bind group layouts, pipeline layout, compute pipeline, and the (initially
    /// empty) buffer pools.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        instance_flags: wgt::InstanceFlags,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device, instance_flags)?;

        // Group 0: read-only storage buffer holding draw metadata.
        let metadata_bind_group_layout = create_bind_group_layout(
            device,
            true,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 1: the user's source indirect buffer, read-only with a dynamic
        // offset. The 16-byte min binding size matches `wgt::DrawIndirectArgs`.
        let src_bind_group_layout = create_bind_group_layout(
            device,
            true,
            true,
            wgt::BufferSize::new(4 * 4).unwrap(),
            hal_label(
                Some("(wgpu internal) Indirect draw validation source bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 2: read-write destination buffer for the validated args.
        let dst_bind_group_layout = create_bind_group_layout(
            device,
            false,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation destination bind group layout"),
                instance_flags,
            ),
        )?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pipeline layout"),
                instance_flags,
            ),
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                Some(metadata_bind_group_layout.as_ref()),
                Some(src_bind_group_layout.as_ref()),
                Some(dst_bind_group_layout.as_ref()),
            ],
            // Two u32 immediates (see `inject_validation_pass`):
            // metadata start index and metadata entry count.
            immediate_size: 8,
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        // The shader's behavior is specialized via pipeline constants on
        // whether first_instance may be non-zero and whether DX12's special
        // constants must be written.
        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
            instance_flags,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

    /// Creates the (group 1) bind group for a user buffer that may be used as
    /// an indirect draw source.
    ///
    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
        instance_flags: wgt::InstanceFlags,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        // The binding size is chosen so that, combined with the dynamic
        // offsets produced by `calculate_src_offsets`, it can address the
        // whole buffer.
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation source bind group"),
                instance_flags,
            ),
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // SAFETY: We calculated the binding size to fit within the buffer.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    /// Pops a destination buffer+bind-group from the pool, creating a new one
    /// if the pool is empty.
    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
        instance_flags: wgt::InstanceFlags,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                // NOTE(review): the second label says "bind group layout" but it
                // labels a bind group — looks like a copy/paste slip; confirm.
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.dst_bind_group_layout.as_ref(),
                    hal_label(Some("(wgpu internal) Indirect draw validation destination buffer"), instance_flags),
                    hal_label(Some("(wgpu internal) Indirect draw validation destination bind group layout"), instance_flags),
                )
            }
        }
    }

    /// Returns destination entries to the pool for reuse.
    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    /// Pops a metadata buffer+bind-group from the pool, creating a new one if
    /// the pool is empty.
    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
        instance_flags: wgt::InstanceFlags,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                    hal_label(
                        Some("(wgpu internal) Indirect draw validation metadata buffer"),
                        instance_flags,
                    ),
                    hal_label(
                        Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
                        instance_flags,
                    ),
                )
            }
        }
    }

    /// Returns metadata entries to the pool for reuse.
    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    ///
    /// The sequence is:
    /// 1. write each batch's metadata into staging buffers,
    /// 2. copy the staging data into pooled metadata buffers,
    /// 3. dispatch the validation pipeline once per batch,
    /// 4. transition the destination buffers back to INDIRECT usage,
    /// with buffer barriers between each stage.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // 64 MiB (~67 MB)

        let mut staging_buffers = Vec::new();

        // First pass: assign each batch a (staging buffer index, offset),
        // starting a new staging buffer whenever the current one would exceed
        // `max_staging_buffer_size`. Buffers are created lazily once their
        // final size is known.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            // Note: this indexes the buffer currently being accumulated,
            // which is pushed either on the next overflow or after the loop.
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Second pass: write each batch's metadata at its assigned offset.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Assign each batch a subrange of a pooled metadata buffer.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Barriers before the copies: staging buffers MAP_WRITE -> COPY_SRC,
        // metadata buffers STORAGE_READ_ONLY -> COPY_DST. `unique` avoids
        // emitting duplicate barriers for buffers shared by multiple batches.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy each batch's metadata from staging into its metadata subrange.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // The staging buffers must outlive this submission; hand them to the
        // queue's temp-resource list for deferred destruction.
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Barriers before the dispatches: metadata buffers back to
        // STORAGE_READ_ONLY, destination buffers INDIRECT -> STORAGE_READ_WRITE.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pass"),
                device.instance_flags,
            ),
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        // One dispatch per batch.
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Immediates: index of the batch's first entry within the
            // metadata buffer (in `MetadataEntry` units), and the entry count.
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, metadata_bind_group, &[]);
            }

            // Make sure the indirect buffer is still valid (not destroyed).
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    src_bind_group,
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, dst_bind_group, &[]);
            }

            // One thread per entry; 64 matches the shader's workgroup size.
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Final barriers: destination buffers back to INDIRECT so the render
        // pass can consume the validated args.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    /// Destroys all owned hal resources, including any pooled entries.
    /// Consumes `self`; destructuring ensures no field is forgotten.
    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}
514
/// Parses, validates, and compiles the draw-validation WGSL shader into a hal
/// shader module.
///
/// Requires the `wgsl` feature; panics at runtime if it is disabled.
fn create_validation_module(
    device: &dyn hal::DynDevice,
    instance_flags: wgt::InstanceFlags,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    // Validate with only the features this internal shader actually needs
    // (immediates), rather than the device's full feature set.
    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: hal_label(
            Some("(wgpu internal) Indirect draw validation shader module"),
            instance_flags,
        ),
        // This is a trusted internal shader; skip the runtime checks user
        // shaders get.
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}
572
573fn create_validation_pipeline(
574    device: &dyn hal::DynDevice,
575    module: &dyn hal::DynShaderModule,
576    pipeline_layout: &dyn hal::DynPipelineLayout,
577    supports_indirect_first_instance: bool,
578    write_d3d12_special_constants: bool,
579    instance_flags: wgt::InstanceFlags,
580) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
581    let pipeline_desc = hal::ComputePipelineDescriptor {
582        label: hal_label(
583            Some("(wgpu internal) Indirect draw validation pipeline"),
584            instance_flags,
585        ),
586        layout: pipeline_layout,
587        stage: hal::ProgrammableStage {
588            module,
589            entry_point: "main",
590            constants: &hashbrown::HashMap::from([
591                (
592                    "supports_indirect_first_instance".to_string(),
593                    f64::from(supports_indirect_first_instance),
594                ),
595                (
596                    "write_d3d12_special_constants".to_string(),
597                    f64::from(write_d3d12_special_constants),
598                ),
599            ]),
600            zero_initialize_workgroup_memory: false,
601        },
602        cache: None,
603    };
604    let pipeline =
605        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
606            hal::PipelineError::Device(error) => {
607                CreateComputePipelineError::Device(DeviceError::from_hal(error))
608            }
609            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
610            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
611                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
612            ),
613            hal::PipelineError::PipelineConstants(_, error) => {
614                CreateComputePipelineError::PipelineConstants(error)
615            }
616        })?;
617
618    Ok(pipeline)
619}
620
621fn create_bind_group_layout(
622    device: &dyn hal::DynDevice,
623    read_only: bool,
624    has_dynamic_offset: bool,
625    min_binding_size: wgt::BufferSize,
626    label: Option<&'static str>,
627) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
628    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
629        label,
630        flags: hal::BindGroupLayoutFlags::empty(),
631        entries: &[wgt::BindGroupLayoutEntry {
632            binding: 0,
633            visibility: wgt::ShaderStages::COMPUTE,
634            ty: wgt::BindingType::Buffer {
635                ty: wgt::BufferBindingType::Storage { read_only },
636                has_dynamic_offset,
637                min_binding_size: Some(min_binding_size),
638            },
639            count: None,
640        }],
641    };
642    let bind_group_layout = unsafe {
643        device
644            .create_bind_group_layout(&bind_group_layout_desc)
645            .map_err(DeviceError::from_hal)?
646    };
647
648    Ok(bind_group_layout)
649}
650
651/// Returns the largest binding size that when combined with dynamic offsets can address the whole buffer.
652fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
653    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size;
654    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
655
656    if buffer_size <= max_storage_buffer_binding_size {
657        buffer_size
658    } else {
659        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
660        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;
661
662        // Can the buffer remainder fit in the binding remainder?
663        // If so, align max binding size and add buffer remainder
664        if buffer_rem <= binding_rem {
665            max_storage_buffer_binding_size - binding_rem + buffer_rem
666        }
667        // If not, align max binding size, shorten it by a chunk and add buffer remainder
668        else {
669            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
670                + buffer_rem
671        }
672    }
673}
674
/// Splits the given `offset` into a dynamic offset & offset.
///
/// Returns `(src_dynamic_offset, src_offset)`: the dynamic offset applied when
/// binding the source buffer, and the remaining offset passed to the shader.
/// The split guarantees the draw args (up to 20 bytes) lie inside the bound
/// range even at the end of the buffer.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // Number of alignment-sized chunks to shave off the dynamic-offset stride
    // so that args straddling chunk boundaries still fit in the binding.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        // No need to adjust since the src_offset is 4 byte aligned.
        4 => 0,
        // With 16/20 bytes of data we can straddle up to 2 8 byte boundaries:
        //  - 16 bytes of data: (4|8|4)
        //  - 20 bytes of data: (4|8|8, 8|8|4)
        8 => 2,
        // With 16/20 bytes of data we can straddle up to 1 16+ byte boundary:
        //  - 16 bytes of data: (4|12, 8|8, 12|4)
        //  - 20 bytes of data: (4|16, 8|12, 12|8, 16|4)
        16.. => 1,
        // Alignment is a power of two >= 4 per WebGPU limits, so other values
        // cannot occur.
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    // Degenerate case: binding too small to step through; bind at 0 and let
    // the in-binding offset carry everything.
    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    // Clamp the dynamic offset so the binding never extends past the buffer.
    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
714
/// A pooled buffer together with the bind group that binds its full range.
#[derive(Debug)]
struct BufferPoolEntry {
    /// Buffer of size [`BUFFER_SIZE`].
    buffer: Box<dyn hal::DynBuffer>,
    /// Bind group covering the whole buffer at binding 0.
    bind_group: Box<dyn hal::DynBindGroup>,
}
720
721fn create_buffer_and_bind_group(
722    device: &dyn hal::DynDevice,
723    usage: wgt::BufferUses,
724    bind_group_layout: &dyn hal::DynBindGroupLayout,
725    buffer_label: Option<&'static str>,
726    bind_group_label: Option<&'static str>,
727) -> Result<BufferPoolEntry, hal::DeviceError> {
728    let buffer_desc = hal::BufferDescriptor {
729        label: buffer_label,
730        size: BUFFER_SIZE.get(),
731        usage,
732        memory_flags: hal::MemoryFlags::empty(),
733    };
734    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
735    let bind_group_desc = hal::BindGroupDescriptor {
736        label: bind_group_label,
737        layout: bind_group_layout,
738        entries: &[hal::BindGroupEntry {
739            binding: 0,
740            resource_index: 0,
741            count: 1,
742        }],
743        // SAFETY: We just created the buffer with this size.
744        buffers: &[hal::BufferBinding::new_unchecked(
745            buffer.as_ref(),
746            0,
747            BUFFER_SIZE,
748        )],
749        samplers: &[],
750        textures: &[],
751        acceleration_structures: &[],
752        external_textures: &[],
753    };
754    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
755    Ok(BufferPoolEntry { buffer, bind_group })
756}
757
/// Bump-allocation cursor into a pooled buffer: which pool entry is current
/// and how many bytes of it have been handed out so far.
#[derive(Clone)]
struct CurrentEntry {
    /// Index into the relevant `DrawResources` entry list.
    index: usize,
    /// Next free byte offset within that entry's buffer.
    offset: u64,
}
763
/// Holds all command buffer-level resources that are needed to validate indirect draws.
pub(crate) struct DrawResources {
    /// Device whose pools these entries were acquired from (and are returned
    /// to on drop).
    device: Arc<Device>,
    /// Destination buffers acquired for this command buffer.
    dst_entries: Vec<BufferPoolEntry>,
    /// Metadata buffers acquired for this command buffer.
    metadata_entries: Vec<BufferPoolEntry>,
}
770
771impl Drop for DrawResources {
772    fn drop(&mut self) {
773        if let Some(ref indirect_validation) = self.device.indirect_validation {
774            let indirect_draw_validation = &indirect_validation.draw;
775            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
776            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
777        }
778    }
779}
780
impl DrawResources {
    /// Creates an empty resource set for one command buffer.
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    /// Returns the destination buffer at `index`.
    ///
    /// Panics if `index` was not previously produced by `get_dst_subrange`.
    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    /// Returns the destination bind group at `index` (same indexing as
    /// `get_dst_buffer`).
    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    /// Returns the metadata buffer at `index`.
    ///
    /// Panics if `index` was not previously produced by `get_metadata_subrange`.
    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    /// Returns the metadata bind group at `index` (same indexing as
    /// `get_metadata_buffer`).
    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    /// Reserves `size` bytes in a destination buffer, acquiring a new pooled
    /// entry when needed. Returns `(entry index, byte offset)`.
    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Grows `dst_entries` on demand so that `index` is valid.
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation
                    .acquire_dst_entry(self.device.raw(), self.device.instance_flags)?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    /// Reserves `size` bytes in a metadata buffer, acquiring a new pooled
    /// entry when needed. Returns `(entry index, byte offset)`.
    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Grows `metadata_entries` on demand so that `index` is valid.
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation
                    .acquire_metadata_entry(self.device.raw(), self.device.instance_flags)?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    /// Shared bump-allocator: hands out `size` bytes from the current entry,
    /// moving to the next entry (materialized via `ensure_entry`) when the
    /// current one cannot fit the request.
    ///
    /// Returns the `CurrentEntry` (index + offset) describing where the
    /// caller's `size` bytes start; `current_entry` is advanced past them.
    /// Note: assumes `size <= BUFFER_SIZE`, otherwise a fresh entry still
    /// cannot fit the request.
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                // Fits in the current entry: return its position and bump.
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                // Doesn't fit: start a new entry after the current one.
                current_entry.index + 1
            }
        } else {
            // First allocation: start at entry 0.
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}
875
/// This must match the `MetadataEntry` struct used by the shader.
///
/// See [`MetadataEntry::new`] for the exact bit-packing.
#[repr(C)]
struct MetadataEntry {
    // Source offset in u32 units (byte offset / 4); bit 31 is set for
    // indexed draws.
    src_offset: u32,
    // Destination offset in u32 units (byte offset / 4); bit 30 carries bit
    // 32 of the vertex/index limit, bit 31 carries bit 32 of the instance
    // limit.
    dst_offset: u32,
    // Low 32 bits of the vertex (or index, for indexed draws) limit.
    vertex_or_index_limit: u32,
    // Low 32 bits of the instance limit.
    instance_limit: u32,
}
884
885impl MetadataEntry {
886    fn new(
887        indexed: bool,
888        src_offset: u64,
889        dst_offset: u64,
890        vertex_or_index_limit: u64,
891        instance_limit: u64,
892    ) -> Self {
893        const U32_MAX_AS_U64: u64 = u32::MAX as u64;
894
895        // NOTE: buffer sizes should never exceed `u32::MAX`.
896        assert!(src_offset <= U32_MAX_AS_U64);
897        assert!(dst_offset <= U32_MAX_AS_U64);
898
899        let src_offset = src_offset as u32;
900        let src_offset = src_offset / 4; // translate byte offset to offset in u32's
901
902        // `src_offset` needs at most 30 bits,
903        // pack `indexed` in bit 31 of `src_offset`
904        let src_offset = src_offset | ((indexed as u32) << 31);
905
906        // max value for limits since first_X and X_count indirect draw arguments are u32
907        let max_limit = U32_MAX_AS_U64 + U32_MAX_AS_U64; // 1 11111111 11111111 11111111 11111110
908
909        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
910        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
911        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32
912
913        let instance_limit = instance_limit.min(max_limit);
914        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
915        let instance_limit = instance_limit as u32; // truncate the limit to a u32
916
917        let dst_offset = dst_offset as u32;
918        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's
919
920        // `dst_offset` needs at most 30 bits,
921        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
922        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
923        let dst_offset =
924            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
925
926        Self {
927            src_offset,
928            dst_offset,
929            vertex_or_index_limit,
930            instance_limit,
931        }
932    }
933}
934
/// A group of indirect draws that share the same source buffer, source
/// dynamic offset, and destination resource (see the key of
/// [`DrawBatcher::batches`]) and can therefore be validated together.
struct DrawIndirectValidationBatch {
    // Buffer holding the unvalidated indirect draw arguments.
    src_buffer: Arc<crate::resource::Buffer>,
    // Dynamic offset applied to `src_buffer`, computed by
    // `calculate_src_offsets`.
    src_dynamic_offset: u64,
    // Index of the destination entry in `DrawResources::dst_entries`.
    dst_resource_index: usize,
    // One metadata entry per draw in this batch.
    entries: Vec<MetadataEntry>,

    // The fields below start at 0 and are initialized only after all
    // entries for the batch have been accumulated (see `DrawBatcher::add`).
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}
946
impl DrawIndirectValidationBatch {
    /// Data to be written to the metadata buffer.
    ///
    /// Reinterprets the accumulated [`MetadataEntry`] values as raw bytes.
    fn metadata(&self) -> &[u8] {
        // SAFETY: `MetadataEntry` is `#[repr(C)]` and consists of four `u32`
        // fields, so it has no padding and every byte of the slice is
        // initialized; the pointer and length describe exactly the
        // `self.entries` allocation, which outlives the returned borrow.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}
958
/// Accumulates all data needed to validate indirect draws.
pub(crate) struct DrawBatcher {
    // Batches keyed by (source buffer tracker index, source dynamic offset,
    // destination resource index); draws that agree on all three are
    // collected into a single batch.
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    // Destination buffer entry currently being bump-allocated from;
    // `None` until the first draw is added.
    current_dst_entry: Option<CurrentEntry>,
}
964
965impl DrawBatcher {
966    pub(crate) fn new() -> Self {
967        Self {
968            batches: FastHashMap::default(),
969            current_dst_entry: None,
970        }
971    }
972
973    /// Add an indirect draw to be validated.
974    ///
975    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
976    /// and the offset to be used for the draw.
977    pub(crate) fn add<'a>(
978        &mut self,
979        indirect_draw_validation_resources: &'a mut DrawResources,
980        device: &Device,
981        src_buffer: &Arc<crate::resource::Buffer>,
982        offset: u64,
983        family: crate::command::DrawCommandFamily,
984        vertex_or_index_limit: u64,
985        instance_limit: u64,
986    ) -> Result<(usize, u64), DeviceError> {
987        // space for D3D12 special constants
988        let extra = if device.backend() == wgt::Backend::Dx12 {
989            3 * size_of::<u32>() as u64
990        } else {
991            0
992        };
993        let stride = extra + crate::command::get_stride_of_indirect_args(family);
994
995        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
996            .get_dst_subrange(stride, &mut self.current_dst_entry)?;
997
998        let buffer_size = src_buffer.size;
999        let limits = device.adapter.limits();
1000        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);
1001
1002        let src_buffer_tracker_index = src_buffer.tracker_index();
1003
1004        let entry = MetadataEntry::new(
1005            family == crate::command::DrawCommandFamily::DrawIndexed,
1006            src_offset,
1007            dst_offset,
1008            vertex_or_index_limit,
1009            instance_limit,
1010        );
1011
1012        match self.batches.entry((
1013            src_buffer_tracker_index,
1014            src_dynamic_offset,
1015            dst_resource_index,
1016        )) {
1017            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
1018                occupied_entry.get_mut().entries.push(entry)
1019            }
1020            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
1021                vacant_entry.insert(DrawIndirectValidationBatch {
1022                    src_buffer: src_buffer.clone(),
1023                    src_dynamic_offset,
1024                    dst_resource_index,
1025                    entries: vec![entry],
1026
1027                    // these will be initialized once we accumulated all entries for the batch
1028                    staging_buffer_index: 0,
1029                    staging_buffer_offset: 0,
1030                    metadata_resource_index: 0,
1031                    metadata_buffer_offset: 0,
1032                });
1033            }
1034        }
1035
1036        Ok((dst_resource_index, dst_offset))
1037    }
1038}