// wgpu_core/indirect_validation/draw.rs

use super::{
    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
    CreateIndirectValidationPipelineError,
};
use crate::{
    command::RenderPassErrorInner,
    device::{queue::TempResource, Device, DeviceError},
    lock::{rank, Mutex},
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
    resource::{RawResourceAccess as _, StagingBuffer, Trackable},
    snatch::SnatchGuard,
    track::TrackerIndex,
    FastHashMap,
};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{
    mem::{size_of, size_of_val},
    num::NonZeroU64,
};
use wgt::Limits;

/// Note: This needs to be under:
///
/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
///
/// = (2^16 - 1) * 2^4 * 2^6
///
/// It is currently set to:
///
/// = (2^16 - 1) * 2^4
///
/// This is enough space for:
///
/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };

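// Illustrative compile-time sanity check (not in the original source): the
// constant above is exactly (2^16 - 1) * 2^4, i.e. 65535 entries of 16 bytes.
const _: () = assert!(BUFFER_SIZE.get() == 65535 * 16);
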
/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_immediate_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}

impl Draw {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            immediate_size: 8,
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // SAFETY: We calculated the binding size to fit within the buffer.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
            }
        }
    }

    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                )
            }
        }
    }

    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // 64 MiB

        let mut staging_buffers = Vec::new();

        let mut current_size = 0;
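        // Assign each batch a (staging buffer, offset) pair, packing batches
        // contiguously and starting a new staging buffer whenever the current
        // one would grow past `max_staging_buffer_size`.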
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

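        // One dispatch per batch: the shader runs one invocation per
        // `MetadataEntry` (workgroup size 64, matching the `div_ceil(64)`
        // below), validating the corresponding indirect draw's arguments.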
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            // Make sure the indirect buffer is still valid.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}

fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}

fn create_bind_group_layout(
    device: &dyn hal::DynDevice,
    read_only: bool,
    has_dynamic_offset: bool,
    min_binding_size: wgt::BufferSize,
) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
        label: None,
        flags: hal::BindGroupLayoutFlags::empty(),
        entries: &[wgt::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgt::ShaderStages::COMPUTE,
            ty: wgt::BindingType::Buffer {
                ty: wgt::BufferBindingType::Storage { read_only },
                has_dynamic_offset,
                min_binding_size: Some(min_binding_size),
            },
            count: None,
        }],
    };
    let bind_group_layout = unsafe {
        device
            .create_bind_group_layout(&bind_group_layout_desc)
            .map_err(DeviceError::from_hal)?
    };

    Ok(bind_group_layout)
}

/// Returns the largest binding size that, when combined with dynamic offsets, can address the
/// whole buffer.
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        // Can the buffer remainder fit in the binding remainder?
        // If so, align max binding size and add buffer remainder
        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        }
        // If not, align max binding size, shorten it by a chunk and add buffer remainder
        else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}
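
// Worked example (illustrative values, not from the original source): with
// max_storage_buffer_binding_size = 256, min_storage_buffer_offset_alignment = 32
// and buffer_size = 1000, we get buffer_rem = 1000 % 32 = 8 and
// binding_rem = 256 % 32 = 0. Since 8 > 0, the result is
// 256 - 0 - 32 + 8 = 232; note that 1000 - 232 = 768 is 32-aligned, so a
// dynamic offset can slide the 232-byte binding all the way to the buffer end.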

/// Splits the given `offset` into a dynamic offset and a binding-relative offset.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        // No need to adjust since the src_offset is 4 byte aligned.
        4 => 0,
        // With 16/20 bytes of data we can straddle up to two 8-byte boundaries:
        //  - 16 bytes of data: (4|8|4)
        //  - 20 bytes of data: (4|8|8, 8|8|4)
        8 => 2,
        // With 16/20 bytes of data we can straddle at most one boundary of 16 bytes or more:
        //  - 16 bytes of data: (4|12, 8|8, 12|4)
        //  - 20 bytes of data: (4|16, 8|12, 12|8, 16|4)
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
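
// Worked example (illustrative, continuing the values above): with
// buffer_size = 1000, binding_size = 232 and alignment 32, chunks = 232 / 32 = 7
// and dynamic_offset_stride = (7 - 1) * 32 = 192. For offset = 500,
// src_dynamic_offset_index = 500 / 192 = 2, so the split is
// (src_dynamic_offset, src_offset) = (384, 116); the 20-byte indexed args at
// offset 116 fit comfortably inside the 232-byte binding.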

#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}

fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        // SAFETY: We just created the buffer with this size.
        buffers: &[hal::BufferBinding::new_unchecked(
            buffer.as_ref(),
            0,
            BUFFER_SIZE,
        )],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
        external_textures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}

#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}

/// Holds all command buffer-level resources that are needed to validate indirect draws.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}

impl Drop for DrawResources {
    fn drop(&mut self) {
        if let Some(ref indirect_validation) = self.device.indirect_validation {
            let indirect_draw_validation = &indirect_validation.draw;
            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
        }
    }
}

impl DrawResources {
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

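    // Bump-allocation over the fixed-size pool buffers: take `size` bytes from
    // the current entry if they fit, otherwise move on to the next entry
    // (creating it on demand). Illustrative sequence: a first request of 20
    // bytes yields (index 0, offset 0), the next yields (index 0, offset 20),
    // and a request that would cross BUFFER_SIZE rolls over to (index 1, offset 0).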
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}

/// This must match the `MetadataEntry` struct used by the shader.
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}

impl MetadataEntry {
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        debug_assert_eq!(
            4,
            size_of_val(&Limits::default().max_storage_buffer_binding_size)
        );

        let src_offset = src_offset as u32; // max_storage_buffer_binding_size is a u32
        let src_offset = src_offset / 4; // translate byte offset to offset in u32's

        // `src_offset` needs at most 30 bits,
        // pack `indexed` in bit 31 of `src_offset`
        let src_offset = src_offset | ((indexed as u32) << 31);

        // max value for limits since first_X and X_count indirect draw arguments are u32
        let max_limit = u32::MAX as u64 + u32::MAX as u64; // 1 11111111 11111111 11111111 11111110

        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32

        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
        let instance_limit = instance_limit as u32; // truncate the limit to a u32

        let dst_offset = dst_offset as u32; // max_storage_buffer_binding_size is a u32
        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's

        // `dst_offset` needs at most 30 bits,
        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}
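
// Packing example (illustrative): for an indexed draw whose source args start
// at byte 40 and whose destination slot starts at byte 16, with both limits
// below 2^32: src_offset = 40 / 4 = 10, and setting the `indexed` bit 31 gives
// 0x8000_000A; dst_offset = 16 / 4 = 4, and since bit 32 of both limits is 0,
// bits 30/31 stay clear, so dst_offset remains 4.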

struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}

impl DrawIndirectValidationBatch {
    /// Data to be written to the metadata buffer.
    fn metadata(&self) -> &[u8] {
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}

/// Accumulates all data needed to validate indirect draws.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    current_dst_entry: Option<CurrentEntry>,
}

impl DrawBatcher {
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

    /// Add an indirect draw to be validated.
    ///
    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
    /// and the offset to be used for the draw.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // space for D3D12 special constants
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],
                    // these will be initialized once all entries for the batch have been accumulated
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}
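
// Illustrative unit tests (not part of the original module) pinning down the
// arithmetic of the pure helpers above; the limit values used here are
// arbitrary sketch inputs, not values any particular backend reports.
#[cfg(test)]
mod draw_validation_tests {
    use super::*;

    fn test_limits() -> Limits {
        Limits {
            max_storage_buffer_binding_size: 256,
            min_storage_buffer_offset_alignment: 32,
            ..Limits::default()
        }
    }

    #[test]
    fn src_binding_size_covers_whole_buffer() {
        let limits = test_limits();
        // Small buffers are bound in their entirety.
        assert_eq!(calculate_src_buffer_binding_size(100, &limits), 100);
        // 1000 % 32 = 8 > 256 % 32 = 0, so the binding shrinks by one chunk:
        // 256 - 0 - 32 + 8 = 232, and 1000 - 232 is 32-aligned.
        assert_eq!(calculate_src_buffer_binding_size(1000, &limits), 232);
    }

    #[test]
    fn src_offsets_split_respects_stride() {
        let limits = test_limits();
        // stride = (232 / 32 - 1) * 32 = 192; 500 / 192 = 2, so the dynamic
        // offset is 384 and the remaining binding-relative offset is 116.
        assert_eq!(calculate_src_offsets(1000, &limits, 500), (384, 116));
    }

    #[test]
    fn subranges_bump_then_roll_over() {
        let mut current = None;
        let a = DrawResources::get_subrange_impl(|_| Ok(()), &mut current, 20).unwrap();
        assert_eq!((a.index, a.offset), (0, 0));
        let b = DrawResources::get_subrange_impl(|_| Ok(()), &mut current, 20).unwrap();
        assert_eq!((b.index, b.offset), (0, 20));
        // A request that no longer fits in entry 0 starts entry 1 at offset 0.
        let c =
            DrawResources::get_subrange_impl(|_| Ok(()), &mut current, BUFFER_SIZE.get()).unwrap();
        assert_eq!((c.index, c.offset), (1, 0));
    }
}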