wgpu_core/indirect_validation/
draw.rs

1use super::{
2    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3    CreateIndirectValidationPipelineError,
4};
5use crate::{
6    command::RenderPassErrorInner,
7    device::{queue::TempResource, Device, DeviceError},
8    hal_label,
9    lock::{rank, Mutex},
10    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
11    resource::{RawResourceAccess as _, StagingBuffer, Trackable},
12    snatch::SnatchGuard,
13    track::TrackerIndex,
14    FastHashMap,
15};
16use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
17use core::{
18    mem::{size_of, size_of_val},
19    num::NonZeroU64,
20};
21use wgt::Limits;
22
/// Size in bytes of each pooled metadata/destination buffer.
///
/// Note: This needs to be under:
///
/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
///
/// = (2^16 - 1) * 2^4 * 2^6
///
/// It is currently set to:
///
/// = (2^16 - 1) * 2^4
///
/// This is enough space for:
///
/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
const BUFFER_SIZE: wgt::BufferSize = wgt::BufferSize::new(1_048_560).unwrap();
38
/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_immediate_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    /// Compiled validation shader module.
    module: Box<dyn hal::DynShaderModule>,
    /// Layout for bind group 0: read-only metadata storage buffer.
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 1: read-only source (indirect) buffer with a dynamic offset.
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 2: read-write destination buffer.
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout combining the three bind group layouts above.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// The validation compute pipeline.
    pipeline: Box<dyn hal::DynComputePipeline>,

    /// Pool of reusable destination (indirect) buffers and their bind groups.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    /// Pool of reusable metadata buffers and their bind groups.
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
62
63impl Draw {
    /// Creates all device-level validation resources: shader module, the three
    /// bind group layouts, the pipeline layout, the compute pipeline, and two
    /// empty buffer pools.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        instance_flags: wgt::InstanceFlags,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device, instance_flags)?;

        // Group 0: read-only metadata storage buffer, no dynamic offset.
        let metadata_bind_group_layout = create_bind_group_layout(
            device,
            true,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 1: read-only source (indirect) buffer with a dynamic offset.
        // Min binding size is 4 * 4 bytes — NOTE(review): presumably sized for
        // one set of draw args; confirm against `validate_draw.wgsl`.
        let src_bind_group_layout = create_bind_group_layout(
            device,
            true,
            true,
            wgt::BufferSize::new(4 * 4).unwrap(),
            hal_label(
                Some("(wgpu internal) Indirect draw validation source bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 2: read-write destination buffer the shader writes validated args into.
        let dst_bind_group_layout = create_bind_group_layout(
            device,
            false,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation destination bind group layout"),
                instance_flags,
            ),
        )?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pipeline layout"),
                instance_flags,
            ),
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                Some(metadata_bind_group_layout.as_ref()),
                Some(src_bind_group_layout.as_ref()),
                Some(dst_bind_group_layout.as_ref()),
            ],
            // Two u32 immediates: metadata start index and metadata count
            // (see `inject_validation_pass`).
            immediate_size: 8,
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
            instance_flags,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }
146
147    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
148    pub(super) fn create_src_bind_group(
149        &self,
150        device: &dyn hal::DynDevice,
151        limits: &Limits,
152        buffer_size: u64,
153        buffer: &dyn hal::DynBuffer,
154        instance_flags: wgt::InstanceFlags,
155    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
156        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
157        let Some(binding_size) = NonZeroU64::new(binding_size) else {
158            return Ok(None);
159        };
160        let hal_desc = hal::BindGroupDescriptor {
161            label: hal_label(
162                Some("(wgpu internal) Indirect draw validation source bind group"),
163                instance_flags,
164            ),
165            layout: self.src_bind_group_layout.as_ref(),
166            entries: &[hal::BindGroupEntry {
167                binding: 0,
168                resource_index: 0,
169                count: 1,
170            }],
171            // SAFETY: We calculated the binding size to fit within the buffer.
172            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
173            samplers: &[],
174            textures: &[],
175            acceleration_structures: &[],
176            external_textures: &[],
177        };
178        unsafe {
179            device
180                .create_bind_group(&hal_desc)
181                .map(Some)
182                .map_err(DeviceError::from_hal)
183        }
184    }
185
186    fn acquire_dst_entry(
187        &self,
188        device: &dyn hal::DynDevice,
189        instance_flags: wgt::InstanceFlags,
190    ) -> Result<BufferPoolEntry, hal::DeviceError> {
191        let mut free_buffers = self.free_indirect_entries.lock();
192        match free_buffers.pop() {
193            Some(buffer) => Ok(buffer),
194            None => {
195                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
196                create_buffer_and_bind_group(
197                    device,
198                    usage,
199                    self.dst_bind_group_layout.as_ref(),
200                    hal_label(Some("(wgpu internal) Indirect draw validation destination buffer"), instance_flags),
201                    hal_label(Some("(wgpu internal) Indirect draw validation destination bind group layout"), instance_flags),
202                )
203            }
204        }
205    }
206
207    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
208        self.free_indirect_entries.lock().extend(entries);
209    }
210
211    fn acquire_metadata_entry(
212        &self,
213        device: &dyn hal::DynDevice,
214        instance_flags: wgt::InstanceFlags,
215    ) -> Result<BufferPoolEntry, hal::DeviceError> {
216        let mut free_buffers = self.free_metadata_entries.lock();
217        match free_buffers.pop() {
218            Some(buffer) => Ok(buffer),
219            None => {
220                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
221                create_buffer_and_bind_group(
222                    device,
223                    usage,
224                    self.metadata_bind_group_layout.as_ref(),
225                    hal_label(
226                        Some("(wgpu internal) Indirect draw validation metadata buffer"),
227                        instance_flags,
228                    ),
229                    hal_label(
230                        Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
231                        instance_flags,
232                    ),
233                )
234            }
235        }
236    }
237
238    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
239        self.free_metadata_entries.lock().extend(entries);
240    }
241
    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    ///
    /// Sequence: metadata is written to mapped staging buffers, copied into
    /// pooled metadata buffers on the GPU, then one dispatch per batch runs
    /// the validation shader, which writes sanitized draw args into the
    /// destination buffers that the render pass reads indirectly.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // ~67MiB

        let mut staging_buffers = Vec::new();

        // Pass 1: assign each batch a (staging buffer, offset) slot, starting
        // a new staging buffer whenever the current one would overflow.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            // In the overflow case this index refers to the *next* staging
            // buffer, created on a later iteration or by the trailing flush below.
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        // Create the final staging buffer holding any remaining metadata.
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Pass 2: write each batch's metadata bytes into its staging slot.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        // Flush the mapped staging buffers so the GPU can read them.
        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Pass 3: reserve a subrange of a pooled metadata buffer per batch.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Transition staging buffers to COPY_SRC and metadata buffers to
        // COPY_DST before the staging -> metadata copies.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy each batch's metadata from its staging slot into its reserved
        // metadata buffer subrange.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // Hand staging buffers off for deferred destruction once the GPU is
        // done with them.
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Transition metadata buffers back to shader-readable and destination
        // buffers to shader-writable for the validation dispatches.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pass"),
                instance_flags,
            ),
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        // One dispatch per batch: bind metadata (group 0), source (group 1,
        // via dynamic offset) and destination (group 2), then launch one
        // thread per metadata entry in workgroups of 64.
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Immediates: where this batch's entries start in the metadata
            // buffer (in `MetadataEntry` units) and how many there are.
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, metadata_bind_group, &[]);
            }

            // Make sure the indirect buffer is still valid.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    src_bind_group,
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, dst_bind_group, &[]);
            }

            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Return destination buffers to the INDIRECT state for the render pass.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }
479
480    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
481        let Draw {
482            module,
483            metadata_bind_group_layout,
484            src_bind_group_layout,
485            dst_bind_group_layout,
486            pipeline_layout,
487            pipeline,
488
489            free_indirect_entries,
490            free_metadata_entries,
491        } = self;
492
493        for entry in free_indirect_entries.into_inner().drain(..) {
494            unsafe {
495                device.destroy_bind_group(entry.bind_group);
496                device.destroy_buffer(entry.buffer);
497            }
498        }
499
500        for entry in free_metadata_entries.into_inner().drain(..) {
501            unsafe {
502                device.destroy_bind_group(entry.bind_group);
503                device.destroy_buffer(entry.buffer);
504            }
505        }
506
507        unsafe {
508            device.destroy_compute_pipeline(pipeline);
509            device.destroy_pipeline_layout(pipeline_layout);
510            device.destroy_bind_group_layout(metadata_bind_group_layout);
511            device.destroy_bind_group_layout(src_bind_group_layout);
512            device.destroy_bind_group_layout(dst_bind_group_layout);
513            device.destroy_shader_module(module);
514        }
515    }
516}
517
/// Parses, validates, and creates the WGSL draw-validation shader module.
///
/// Panics at runtime if the `wgsl` feature is disabled, since the embedded
/// shader source cannot be parsed without it.
fn create_validation_module(
    device: &dyn hal::DynDevice,
    instance_flags: wgt::InstanceFlags,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    // Validate with only the IMMEDIATES feature enabled and no downlevel
    // capabilities — this is all the internal shader needs.
    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: hal_label(
            Some("(wgpu internal) Indirect draw validation shader module"),
            instance_flags,
        ),
        // Internal, already-validated shader: skip runtime bounds checks.
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}
575
576fn create_validation_pipeline(
577    device: &dyn hal::DynDevice,
578    module: &dyn hal::DynShaderModule,
579    pipeline_layout: &dyn hal::DynPipelineLayout,
580    supports_indirect_first_instance: bool,
581    write_d3d12_special_constants: bool,
582    instance_flags: wgt::InstanceFlags,
583) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
584    let pipeline_desc = hal::ComputePipelineDescriptor {
585        label: hal_label(
586            Some("(wgpu internal) Indirect draw validation pipeline"),
587            instance_flags,
588        ),
589        layout: pipeline_layout,
590        stage: hal::ProgrammableStage {
591            module,
592            entry_point: "main",
593            constants: &hashbrown::HashMap::from([
594                (
595                    "supports_indirect_first_instance".to_string(),
596                    f64::from(supports_indirect_first_instance),
597                ),
598                (
599                    "write_d3d12_special_constants".to_string(),
600                    f64::from(write_d3d12_special_constants),
601                ),
602            ]),
603            zero_initialize_workgroup_memory: false,
604        },
605        cache: None,
606    };
607    let pipeline =
608        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
609            hal::PipelineError::Device(error) => {
610                CreateComputePipelineError::Device(DeviceError::from_hal(error))
611            }
612            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
613            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
614                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
615            ),
616            hal::PipelineError::PipelineConstants(_, error) => {
617                CreateComputePipelineError::PipelineConstants(error)
618            }
619        })?;
620
621    Ok(pipeline)
622}
623
624fn create_bind_group_layout(
625    device: &dyn hal::DynDevice,
626    read_only: bool,
627    has_dynamic_offset: bool,
628    min_binding_size: wgt::BufferSize,
629    label: Option<&'static str>,
630) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
631    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
632        label,
633        flags: hal::BindGroupLayoutFlags::empty(),
634        entries: &[wgt::BindGroupLayoutEntry {
635            binding: 0,
636            visibility: wgt::ShaderStages::COMPUTE,
637            ty: wgt::BindingType::Buffer {
638                ty: wgt::BufferBindingType::Storage { read_only },
639                has_dynamic_offset,
640                min_binding_size: Some(min_binding_size),
641            },
642            count: None,
643        }],
644    };
645    let bind_group_layout = unsafe {
646        device
647            .create_bind_group_layout(&bind_group_layout_desc)
648            .map_err(DeviceError::from_hal)?
649    };
650
651    Ok(bind_group_layout)
652}
653
654/// Returns the largest binding size that when combined with dynamic offsets can address the whole buffer.
655fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
656    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
657    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
658
659    if buffer_size <= max_storage_buffer_binding_size {
660        buffer_size
661    } else {
662        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
663        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;
664
665        // Can the buffer remainder fit in the binding remainder?
666        // If so, align max binding size and add buffer remainder
667        if buffer_rem <= binding_rem {
668            max_storage_buffer_binding_size - binding_rem + buffer_rem
669        }
670        // If not, align max binding size, shorten it by a chunk and add buffer remainder
671        else {
672            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
673                + buffer_rem
674        }
675    }
676}
677
/// Splits the given `offset` into a dynamic offset & offset.
///
/// Returns `(src_dynamic_offset, src_offset)` whose sum equals `offset`. The
/// dynamic offset is a multiple of `dynamic_offset_stride` and is clamped so
/// the binding (as computed by [`calculate_src_buffer_binding_size`]) never
/// extends past the end of the buffer.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // How many alignment chunks of slack to keep at the end of the binding so
    // that draw args straddling chunk boundaries remain fully inside it.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        // No need to adjust since the src_offset is 4 byte aligned.
        4 => 0,
        // With 16/20 bytes of data we can straddle up to 2 8 byte boundaries:
        //  - 16 bytes of data: (4|8|4)
        //  - 20 bytes of data: (4|8|8, 8|8|4)
        8 => 2,
        // With 16/20 bytes of data we can straddle up to 1 16+ byte boundary:
        //  - 16 bytes of data: (4|12, 8|8, 12|4)
        //  - 20 bytes of data: (4|16, 8|12, 12|8, 16|4)
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    // Binding too small for any dynamic stepping; use the raw offset directly.
    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    // Clamp so the binding window never runs past the end of the buffer.
    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
717
/// A pooled buffer paired with a bind group that exposes the whole buffer.
#[derive(Debug)]
struct BufferPoolEntry {
    /// Backing buffer of `BUFFER_SIZE` bytes.
    buffer: Box<dyn hal::DynBuffer>,
    /// Bind group binding all of `buffer` at `@binding(0)`.
    bind_group: Box<dyn hal::DynBindGroup>,
}
723
724fn create_buffer_and_bind_group(
725    device: &dyn hal::DynDevice,
726    usage: wgt::BufferUses,
727    bind_group_layout: &dyn hal::DynBindGroupLayout,
728    buffer_label: Option<&'static str>,
729    bind_group_label: Option<&'static str>,
730) -> Result<BufferPoolEntry, hal::DeviceError> {
731    let buffer_desc = hal::BufferDescriptor {
732        label: buffer_label,
733        size: BUFFER_SIZE.get(),
734        usage,
735        memory_flags: hal::MemoryFlags::empty(),
736    };
737    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
738    let bind_group_desc = hal::BindGroupDescriptor {
739        label: bind_group_label,
740        layout: bind_group_layout,
741        entries: &[hal::BindGroupEntry {
742            binding: 0,
743            resource_index: 0,
744            count: 1,
745        }],
746        // SAFETY: We just created the buffer with this size.
747        buffers: &[hal::BufferBinding::new_unchecked(
748            buffer.as_ref(),
749            0,
750            BUFFER_SIZE,
751        )],
752        samplers: &[],
753        textures: &[],
754        acceleration_structures: &[],
755        external_textures: &[],
756    };
757    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
758    Ok(BufferPoolEntry { buffer, bind_group })
759}
760
/// Cursor into the buffer pool entry currently being sub-allocated from.
#[derive(Clone)]
struct CurrentEntry {
    /// Index of the entry within the pool's entry list.
    index: usize,
    /// Offset of the next free byte within that entry's buffer.
    offset: u64,
}
766
/// Holds all command buffer-level resources that are needed to validate indirect draws.
pub(crate) struct DrawResources {
    /// Device the pool entries were acquired from; used to return them on drop.
    device: Arc<Device>,
    /// Destination (indirect) buffer entries in use by this command buffer.
    dst_entries: Vec<BufferPoolEntry>,
    /// Metadata buffer entries in use by this command buffer.
    metadata_entries: Vec<BufferPoolEntry>,
}
773
774impl Drop for DrawResources {
775    fn drop(&mut self) {
776        if let Some(ref indirect_validation) = self.device.indirect_validation {
777            let indirect_draw_validation = &indirect_validation.draw;
778            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
779            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
780        }
781    }
782}
783
784impl DrawResources {
785    pub(crate) fn new(device: Arc<Device>) -> Self {
786        DrawResources {
787            device,
788            dst_entries: Vec::new(),
789            metadata_entries: Vec::new(),
790        }
791    }
792
793    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
794        self.dst_entries.get(index).unwrap().buffer.as_ref()
795    }
796
797    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
798        self.dst_entries.get(index).unwrap().bind_group.as_ref()
799    }
800
801    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
802        self.metadata_entries.get(index).unwrap().buffer.as_ref()
803    }
804
805    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
806        self.metadata_entries
807            .get(index)
808            .unwrap()
809            .bind_group
810            .as_ref()
811    }
812
    /// Reserves `size` bytes in a pooled destination (indirect) buffer,
    /// acquiring a new pool entry when the current one is exhausted.
    ///
    /// Returns the entry's index into `self.dst_entries` plus the byte offset
    /// of the reserved subrange; `current_entry` carries the allocation
    /// cursor across successive calls.
    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        // NOTE(review): assumes indirect validation is always present when
        // `DrawResources` is in use — confirm against the device setup path.
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Grows the entry list lazily until `index` is valid.
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation
                    .acquire_dst_entry(self.device.raw(), self.device.instance_flags)?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
830
831    fn get_metadata_subrange(
832        &mut self,
833        size: u64,
834        current_entry: &mut Option<CurrentEntry>,
835    ) -> Result<(usize, u64), DeviceError> {
836        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
837        let ensure_entry = |index: usize| {
838            if self.metadata_entries.len() <= index {
839                let entry = indirect_draw_validation
840                    .acquire_metadata_entry(self.device.raw(), self.device.instance_flags)?;
841                self.metadata_entries.push(entry);
842            }
843            Ok(())
844        };
845        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
846        Ok((entry_data.index, entry_data.offset))
847    }
848
849    fn get_subrange_impl(
850        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
851        current_entry: &mut Option<CurrentEntry>,
852        size: u64,
853    ) -> Result<CurrentEntry, DeviceError> {
854        let index = if let Some(current_entry) = current_entry.as_mut() {
855            if current_entry.offset + size <= BUFFER_SIZE.get() {
856                let entry_data = current_entry.clone();
857                current_entry.offset += size;
858                return Ok(entry_data);
859            } else {
860                current_entry.index + 1
861            }
862        } else {
863            0
864        };
865
866        ensure_entry(index).map_err(DeviceError::from_hal)?;
867
868        let entry_data = CurrentEntry { index, offset: 0 };
869
870        *current_entry = Some(CurrentEntry {
871            index,
872            offset: size,
873        });
874
875        Ok(entry_data)
876    }
877}
878
/// This must match the `MetadataEntry` struct used by the shader.
///
/// All fields are bit-packed by [`MetadataEntry::new`]; see it for the exact
/// layout.
#[repr(C)]
struct MetadataEntry {
    /// Offset (in u32's) of the draw args in the source buffer; bit 31 holds
    /// whether the draw is indexed.
    src_offset: u32,
    /// Offset (in u32's) of the draw args in the destination buffer; bits 30
    /// and 31 hold bit 32 of `vertex_or_index_limit` and `instance_limit`,
    /// respectively.
    dst_offset: u32,
    /// Lower 32 bits of the (clamped) vertex-or-index limit.
    vertex_or_index_limit: u32,
    /// Lower 32 bits of the (clamped) instance limit.
    instance_limit: u32,
}
887
888impl MetadataEntry {
889    fn new(
890        indexed: bool,
891        src_offset: u64,
892        dst_offset: u64,
893        vertex_or_index_limit: u64,
894        instance_limit: u64,
895    ) -> Self {
896        debug_assert_eq!(
897            4,
898            size_of_val(&Limits::default().max_storage_buffer_binding_size)
899        );
900
901        let src_offset = src_offset as u32; // max_storage_buffer_binding_size is a u32
902        let src_offset = src_offset / 4; // translate byte offset to offset in u32's
903
904        // `src_offset` needs at most 30 bits,
905        // pack `indexed` in bit 31 of `src_offset`
906        let src_offset = src_offset | ((indexed as u32) << 31);
907
908        // max value for limits since first_X and X_count indirect draw arguments are u32
909        let max_limit = u32::MAX as u64 + u32::MAX as u64; // 1 11111111 11111111 11111111 11111110
910
911        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
912        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
913        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32
914
915        let instance_limit = instance_limit.min(max_limit);
916        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
917        let instance_limit = instance_limit as u32; // truncate the limit to a u32
918
919        let dst_offset = dst_offset as u32; // max_storage_buffer_binding_size is a u32
920        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's
921
922        // `dst_offset` needs at most 30 bits,
923        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
924        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
925        let dst_offset =
926            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
927
928        Self {
929            src_offset,
930            dst_offset,
931            vertex_or_index_limit,
932            instance_limit,
933        }
934    }
935}
936
/// A batch of indirect draws that can be validated together.
///
/// Draws are grouped by source buffer, source dynamic offset, and destination
/// buffer (see [`DrawBatcher::add`]).
struct DrawIndirectValidationBatch {
    /// Buffer holding the unvalidated indirect draw arguments.
    src_buffer: Arc<crate::resource::Buffer>,
    /// Dynamic offset into `src_buffer` for this batch, as computed by
    /// `calculate_src_offsets`.
    src_dynamic_offset: u64,
    /// Index of the destination buffer in the [`DrawResources`] this batch
    /// was built against.
    dst_resource_index: usize,
    /// One packed [`MetadataEntry`] per draw in the batch.
    entries: Vec<MetadataEntry>,

    // The fields below are zero placeholders until all entries for the batch
    // have been accumulated (see `DrawBatcher::add`).
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}
948
impl DrawIndirectValidationBatch {
    /// Data to be written to the metadata buffer.
    ///
    /// Reinterprets the accumulated [`MetadataEntry`]s as a byte slice.
    fn metadata(&self) -> &[u8] {
        // SAFETY: `MetadataEntry` is `#[repr(C)]` with four `u32` fields, so
        // it has no padding and its bytes are fully initialized. The
        // pointer/length pair comes from a live `Vec`, so it denotes a single
        // allocation of `len * size_of::<MetadataEntry>()` readable bytes, and
        // the returned slice's lifetime is tied to `&self`, which keeps the
        // `Vec` alive and un-reallocated for the duration of the borrow.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}
960
/// Accumulates all data needed to validate indirect draws.
pub(crate) struct DrawBatcher {
    /// In-flight batches keyed by (source buffer tracker index, source dynamic
    /// offset, destination resource index); draws that agree on all three are
    /// validated as a single batch.
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    /// Destination buffer entry currently being filled, so consecutive draws
    /// are packed into the same buffer until it runs out of space.
    current_dst_entry: Option<CurrentEntry>,
}
966
967impl DrawBatcher {
968    pub(crate) fn new() -> Self {
969        Self {
970            batches: FastHashMap::default(),
971            current_dst_entry: None,
972        }
973    }
974
975    /// Add an indirect draw to be validated.
976    ///
977    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
978    /// and the offset to be used for the draw.
979    pub(crate) fn add<'a>(
980        &mut self,
981        indirect_draw_validation_resources: &'a mut DrawResources,
982        device: &Device,
983        src_buffer: &Arc<crate::resource::Buffer>,
984        offset: u64,
985        family: crate::command::DrawCommandFamily,
986        vertex_or_index_limit: u64,
987        instance_limit: u64,
988    ) -> Result<(usize, u64), DeviceError> {
989        // space for D3D12 special constants
990        let extra = if device.backend() == wgt::Backend::Dx12 {
991            3 * size_of::<u32>() as u64
992        } else {
993            0
994        };
995        let stride = extra + crate::command::get_stride_of_indirect_args(family);
996
997        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
998            .get_dst_subrange(stride, &mut self.current_dst_entry)?;
999
1000        let buffer_size = src_buffer.size;
1001        let limits = device.adapter.limits();
1002        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);
1003
1004        let src_buffer_tracker_index = src_buffer.tracker_index();
1005
1006        let entry = MetadataEntry::new(
1007            family == crate::command::DrawCommandFamily::DrawIndexed,
1008            src_offset,
1009            dst_offset,
1010            vertex_or_index_limit,
1011            instance_limit,
1012        );
1013
1014        match self.batches.entry((
1015            src_buffer_tracker_index,
1016            src_dynamic_offset,
1017            dst_resource_index,
1018        )) {
1019            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
1020                occupied_entry.get_mut().entries.push(entry)
1021            }
1022            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
1023                vacant_entry.insert(DrawIndirectValidationBatch {
1024                    src_buffer: src_buffer.clone(),
1025                    src_dynamic_offset,
1026                    dst_resource_index,
1027                    entries: vec![entry],
1028
1029                    // these will be initialized once we accumulated all entries for the batch
1030                    staging_buffer_index: 0,
1031                    staging_buffer_offset: 0,
1032                    metadata_resource_index: 0,
1033                    metadata_buffer_offset: 0,
1034                });
1035            }
1036        }
1037
1038        Ok((dst_resource_index, dst_offset))
1039    }
1040}