1use super::{
2 utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3 CreateIndirectValidationPipelineError,
4};
5use crate::{
6 command::{get_src_stride_of_indirect_args, RenderPassErrorInner},
7 device::{queue::TempResource, Device, DeviceError},
8 hal_label,
9 lock::{rank, Mutex},
10 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
11 resource::{RawResourceAccess as _, StagingBuffer, Trackable},
12 snatch::SnatchGuard,
13 track::TrackerIndex,
14 FastHashMap,
15};
16use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
17use core::{mem::size_of, num::NonZeroU64};
18use wgt::Limits;
19
/// Size in bytes of every pooled destination and metadata buffer.
///
/// The same value is used as the `min_binding_size` of the metadata and
/// destination bind group layouts, as the size of the bind group created for
/// each pooled buffer, and as the capacity against which sub-range
/// allocations are checked. Numerically it equals `65_535 * 16`.
const BUFFER_SIZE: wgt::BufferSize = wgt::BufferSize::new(1_048_560).unwrap();
35
/// Device-level state used to validate indirect draw arguments on the GPU.
///
/// Owns the compute pipeline (with its shader module and layouts) that the
/// validation pass runs, plus two pools of reusable buffer/bind-group pairs.
/// All HAL objects held here are released explicitly by `dispose`.
#[derive(Debug)]
pub(crate) struct Draw {
    /// Shader module built from `validate_draw.wgsl`.
    module: Box<dyn hal::DynShaderModule>,
    /// Layout of bind group 0: read-only storage buffer holding metadata entries.
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout of bind group 1: read-only storage buffer with a dynamic offset
    /// (the user's indirect buffer).
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout of bind group 2: read-write storage buffer receiving the
    /// validated arguments.
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout combining the three bind group layouts above.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline executed by `inject_validation_pass`.
    pipeline: Box<dyn hal::DynComputePipeline>,

    /// Pool of destination buffers (and their bind groups) ready for reuse.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    /// Pool of metadata buffers (and their bind groups) ready for reuse.
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
59
60impl Draw {
61 pub(super) fn new(
62 device: &dyn hal::DynDevice,
63 required_features: &wgt::Features,
64 instance_flags: wgt::InstanceFlags,
65 backend: wgt::Backend,
66 limits: &Limits,
67 ) -> Result<Self, CreateIndirectValidationPipelineError> {
68 assert!(limits.max_buffer_size <= u32::MAX as u64);
73
74 let module = create_validation_module(device, instance_flags)?;
75
76 let metadata_bind_group_layout = create_bind_group_layout(
77 device,
78 true,
79 false,
80 BUFFER_SIZE,
81 hal_label(
82 Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
83 instance_flags,
84 ),
85 )?;
86 let src_bind_group_layout = create_bind_group_layout(
87 device,
88 true,
89 true,
90 wgt::BufferSize::new(4 * 4).unwrap(),
91 hal_label(
92 Some("(wgpu internal) Indirect draw validation source bind group layout"),
93 instance_flags,
94 ),
95 )?;
96 let dst_bind_group_layout = create_bind_group_layout(
97 device,
98 false,
99 false,
100 BUFFER_SIZE,
101 hal_label(
102 Some("(wgpu internal) Indirect draw validation destination bind group layout"),
103 instance_flags,
104 ),
105 )?;
106
107 let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
108 label: hal_label(
109 Some("(wgpu internal) Indirect draw validation pipeline layout"),
110 instance_flags,
111 ),
112 flags: hal::PipelineLayoutFlags::empty(),
113 bind_group_layouts: &[
114 Some(metadata_bind_group_layout.as_ref()),
115 Some(src_bind_group_layout.as_ref()),
116 Some(dst_bind_group_layout.as_ref()),
117 ],
118 immediate_size: 8,
119 };
120 let pipeline_layout = unsafe {
121 device
122 .create_pipeline_layout(&pipeline_layout_desc)
123 .map_err(DeviceError::from_hal)?
124 };
125
126 let supports_indirect_first_instance =
127 required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
128 let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
129 let pipeline = create_validation_pipeline(
130 device,
131 module.as_ref(),
132 pipeline_layout.as_ref(),
133 supports_indirect_first_instance,
134 write_d3d12_special_constants,
135 instance_flags,
136 )?;
137
138 Ok(Self {
139 module,
140 metadata_bind_group_layout,
141 src_bind_group_layout,
142 dst_bind_group_layout,
143 pipeline_layout,
144 pipeline,
145
146 free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
147 free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
148 })
149 }
150
151 pub(super) fn create_src_bind_group(
153 &self,
154 device: &dyn hal::DynDevice,
155 limits: &Limits,
156 buffer_size: u64,
157 buffer: &dyn hal::DynBuffer,
158 instance_flags: wgt::InstanceFlags,
159 ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
160 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
161 let Some(binding_size) = NonZeroU64::new(binding_size) else {
162 return Ok(None);
163 };
164 let hal_desc = hal::BindGroupDescriptor {
165 label: hal_label(
166 Some("(wgpu internal) Indirect draw validation source bind group"),
167 instance_flags,
168 ),
169 layout: self.src_bind_group_layout.as_ref(),
170 entries: &[hal::BindGroupEntry {
171 binding: 0,
172 resource_index: 0,
173 count: 1,
174 }],
175 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
177 samplers: &[],
178 textures: &[],
179 acceleration_structures: &[],
180 external_textures: &[],
181 };
182 unsafe {
183 device
184 .create_bind_group(&hal_desc)
185 .map(Some)
186 .map_err(DeviceError::from_hal)
187 }
188 }
189
190 fn acquire_dst_entry(
191 &self,
192 device: &dyn hal::DynDevice,
193 instance_flags: wgt::InstanceFlags,
194 ) -> Result<BufferPoolEntry, hal::DeviceError> {
195 let mut free_buffers = self.free_indirect_entries.lock();
196 match free_buffers.pop() {
197 Some(buffer) => Ok(buffer),
198 None => {
199 let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
200 create_buffer_and_bind_group(
201 device,
202 usage,
203 self.dst_bind_group_layout.as_ref(),
204 hal_label(Some("(wgpu internal) Indirect draw validation destination buffer"), instance_flags),
205 hal_label(Some("(wgpu internal) Indirect draw validation destination bind group layout"), instance_flags),
206 )
207 }
208 }
209 }
210
211 fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
212 self.free_indirect_entries.lock().extend(entries);
213 }
214
215 fn acquire_metadata_entry(
216 &self,
217 device: &dyn hal::DynDevice,
218 instance_flags: wgt::InstanceFlags,
219 ) -> Result<BufferPoolEntry, hal::DeviceError> {
220 let mut free_buffers = self.free_metadata_entries.lock();
221 match free_buffers.pop() {
222 Some(buffer) => Ok(buffer),
223 None => {
224 let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
225 create_buffer_and_bind_group(
226 device,
227 usage,
228 self.metadata_bind_group_layout.as_ref(),
229 hal_label(
230 Some("(wgpu internal) Indirect draw validation metadata buffer"),
231 instance_flags,
232 ),
233 hal_label(
234 Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
235 instance_flags,
236 ),
237 )
238 }
239 }
240 }
241
242 fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
243 self.free_metadata_entries.lock().extend(entries);
244 }
245
246 pub(crate) fn inject_validation_pass(
248 &self,
249 device: &Arc<Device>,
250 snatch_guard: &SnatchGuard,
251 resources: &mut DrawResources,
252 temp_resources: &mut Vec<TempResource>,
253 encoder: &mut dyn hal::DynCommandEncoder,
254 batcher: DrawBatcher,
255 ) -> Result<(), RenderPassErrorInner> {
256 let mut batches = batcher.batches;
257
258 if batches.is_empty() {
259 return Ok(());
260 }
261
262 let max_staging_buffer_size = 1 << 26; let mut staging_buffers = Vec::new();
265
266 let mut current_size = 0;
267 for batch in batches.values_mut() {
268 let data = batch.metadata();
269 let offset = if current_size + data.len() > max_staging_buffer_size {
270 let staging_buffer =
271 StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
272 staging_buffers.push(staging_buffer);
273 current_size = data.len();
274 0
275 } else {
276 let offset = current_size;
277 current_size += data.len();
278 offset as u64
279 };
280 batch.staging_buffer_index = staging_buffers.len();
281 batch.staging_buffer_offset = offset;
282 }
283 if current_size != 0 {
284 let staging_buffer =
285 StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
286 staging_buffers.push(staging_buffer);
287 }
288
289 for batch in batches.values() {
290 let data = batch.metadata();
291 let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
292 unsafe {
293 staging_buffer.write_with_offset(
294 data,
295 0,
296 batch.staging_buffer_offset as isize,
297 data.len(),
298 )
299 };
300 }
301
302 let staging_buffers: Vec<_> = staging_buffers
303 .into_iter()
304 .map(|buffer| buffer.flush())
305 .collect();
306
307 let mut current_metadata_entry = None;
308 for batch in batches.values_mut() {
309 let data = batch.metadata();
310 let (metadata_resource_index, metadata_buffer_offset) =
311 resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
312 batch.metadata_resource_index = metadata_resource_index;
313 batch.metadata_buffer_offset = metadata_buffer_offset;
314 }
315
316 let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
317 let unique_index_scratch = &mut UniqueIndexScratch::new();
318
319 BufferBarriers::new(buffer_barrier_scratch)
320 .extend(
321 batches
322 .values()
323 .map(|batch| batch.staging_buffer_index)
324 .unique(unique_index_scratch)
325 .map(|index| hal::BufferBarrier {
326 buffer: staging_buffers[index].raw(),
327 usage: hal::StateTransition {
328 from: wgt::BufferUses::MAP_WRITE,
329 to: wgt::BufferUses::COPY_SRC,
330 },
331 }),
332 )
333 .extend(
334 batches
335 .values()
336 .map(|batch| batch.metadata_resource_index)
337 .unique(unique_index_scratch)
338 .map(|index| hal::BufferBarrier {
339 buffer: resources.get_metadata_buffer(index),
340 usage: hal::StateTransition {
341 from: wgt::BufferUses::STORAGE_READ_ONLY,
342 to: wgt::BufferUses::COPY_DST,
343 },
344 }),
345 )
346 .encode(encoder);
347
348 for batch in batches.values() {
349 let data = batch.metadata();
350 let data_size = NonZeroU64::new(data.len() as u64).unwrap();
351
352 let staging_buffer = &staging_buffers[batch.staging_buffer_index];
353
354 let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);
355
356 unsafe {
357 encoder.copy_buffer_to_buffer(
358 staging_buffer.raw(),
359 metadata_buffer,
360 &[hal::BufferCopy {
361 src_offset: batch.staging_buffer_offset,
362 dst_offset: batch.metadata_buffer_offset,
363 size: data_size,
364 }],
365 );
366 }
367 }
368
369 for staging_buffer in staging_buffers {
370 temp_resources.push(TempResource::StagingBuffer(staging_buffer));
371 }
372
373 BufferBarriers::new(buffer_barrier_scratch)
374 .extend(
375 batches
376 .values()
377 .map(|batch| batch.metadata_resource_index)
378 .unique(unique_index_scratch)
379 .map(|index| hal::BufferBarrier {
380 buffer: resources.get_metadata_buffer(index),
381 usage: hal::StateTransition {
382 from: wgt::BufferUses::COPY_DST,
383 to: wgt::BufferUses::STORAGE_READ_ONLY,
384 },
385 }),
386 )
387 .extend(
388 batches
389 .values()
390 .map(|batch| batch.dst_resource_index)
391 .unique(unique_index_scratch)
392 .map(|index| hal::BufferBarrier {
393 buffer: resources.get_dst_buffer(index),
394 usage: hal::StateTransition {
395 from: wgt::BufferUses::INDIRECT,
396 to: wgt::BufferUses::STORAGE_READ_WRITE,
397 },
398 }),
399 )
400 .encode(encoder);
401
402 let desc = hal::ComputePassDescriptor {
403 label: hal_label(
404 Some("(wgpu internal) Indirect draw validation pass"),
405 device.instance_flags,
406 ),
407 timestamp_writes: None,
408 };
409 unsafe {
410 encoder.begin_compute_pass(&desc);
411 }
412 unsafe {
413 encoder.set_compute_pipeline(self.pipeline.as_ref());
414 }
415
416 for batch in batches.values() {
417 let pipeline_layout = self.pipeline_layout.as_ref();
418
419 let metadata_start =
420 (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
421 let metadata_count = batch.entries.len() as u32;
422 unsafe {
423 encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
424 }
425
426 let metadata_bind_group =
427 resources.get_metadata_bind_group(batch.metadata_resource_index);
428 unsafe {
429 encoder.set_bind_group(pipeline_layout, 0, metadata_bind_group, &[]);
430 }
431
432 batch.src_buffer.try_raw(snatch_guard)?;
434
435 let src_bind_group = batch
436 .src_buffer
437 .indirect_validation_bind_groups
438 .get(snatch_guard)
439 .unwrap()
440 .draw
441 .as_ref();
442 unsafe {
443 encoder.set_bind_group(
444 pipeline_layout,
445 1,
446 src_bind_group,
447 &[u64_offset_to_u32_offset(batch.src_dynamic_offset)],
448 );
449 }
450
451 let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
452 unsafe {
453 encoder.set_bind_group(pipeline_layout, 2, dst_bind_group, &[]);
454 }
455
456 unsafe {
457 encoder.dispatch_workgroups([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
458 }
459 }
460
461 unsafe {
462 encoder.end_compute_pass();
463 }
464
465 BufferBarriers::new(buffer_barrier_scratch)
466 .extend(
467 batches
468 .values()
469 .map(|batch| batch.dst_resource_index)
470 .unique(unique_index_scratch)
471 .map(|index| hal::BufferBarrier {
472 buffer: resources.get_dst_buffer(index),
473 usage: hal::StateTransition {
474 from: wgt::BufferUses::STORAGE_READ_WRITE,
475 to: wgt::BufferUses::INDIRECT,
476 },
477 }),
478 )
479 .encode(encoder);
480
481 Ok(())
482 }
483
484 pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
485 let Draw {
486 module,
487 metadata_bind_group_layout,
488 src_bind_group_layout,
489 dst_bind_group_layout,
490 pipeline_layout,
491 pipeline,
492
493 free_indirect_entries,
494 free_metadata_entries,
495 } = self;
496
497 for entry in free_indirect_entries.into_inner().drain(..) {
498 unsafe {
499 device.destroy_bind_group(entry.bind_group);
500 device.destroy_buffer(entry.buffer);
501 }
502 }
503
504 for entry in free_metadata_entries.into_inner().drain(..) {
505 unsafe {
506 device.destroy_bind_group(entry.bind_group);
507 device.destroy_buffer(entry.buffer);
508 }
509 }
510
511 unsafe {
512 device.destroy_compute_pipeline(pipeline);
513 device.destroy_pipeline_layout(pipeline_layout);
514 device.destroy_bind_group_layout(metadata_bind_group_layout);
515 device.destroy_bind_group_layout(src_bind_group_layout);
516 device.destroy_bind_group_layout(dst_bind_group_layout);
517 device.destroy_shader_module(module);
518 }
519 }
520}
521
522fn create_validation_module(
523 device: &dyn hal::DynDevice,
524 instance_flags: wgt::InstanceFlags,
525) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
526 let src = include_str!("./validate_draw.wgsl");
527
528 #[cfg(feature = "wgsl")]
529 let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
530 CreateShaderModuleError::Parsing(naga::error::ShaderError {
531 source: src.to_string(),
532 label: None,
533 inner: Box::new(inner),
534 })
535 })?;
536 #[cfg(not(feature = "wgsl"))]
537 #[allow(clippy::diverging_sub_expression)]
538 let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");
539
540 let info = crate::device::create_validator(
541 wgt::Features::IMMEDIATES,
542 wgt::DownlevelFlags::empty(),
543 naga::valid::ValidationFlags::all(),
544 )
545 .validate(&module)
546 .map_err(|inner| {
547 CreateShaderModuleError::Validation(naga::error::ShaderError {
548 source: src.to_string(),
549 label: None,
550 inner: Box::new(inner),
551 })
552 })?;
553 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
554 module: alloc::borrow::Cow::Owned(module),
555 info,
556 debug_source: None,
557 });
558 let hal_desc = hal::ShaderModuleDescriptor {
559 label: hal_label(
560 Some("(wgpu internal) Indirect draw validation shader module"),
561 instance_flags,
562 ),
563 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
564 };
565 let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
566 |error| match error {
567 hal::ShaderError::Device(error) => {
568 CreateShaderModuleError::Device(DeviceError::from_hal(error))
569 }
570 hal::ShaderError::Compilation(ref msg) => {
571 log::error!("Shader error: {msg}");
572 CreateShaderModuleError::Generation
573 }
574 },
575 )?;
576
577 Ok(module)
578}
579
580fn create_validation_pipeline(
581 device: &dyn hal::DynDevice,
582 module: &dyn hal::DynShaderModule,
583 pipeline_layout: &dyn hal::DynPipelineLayout,
584 supports_indirect_first_instance: bool,
585 write_d3d12_special_constants: bool,
586 instance_flags: wgt::InstanceFlags,
587) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
588 let pipeline_desc = hal::ComputePipelineDescriptor {
589 label: hal_label(
590 Some("(wgpu internal) Indirect draw validation pipeline"),
591 instance_flags,
592 ),
593 layout: pipeline_layout,
594 stage: hal::ProgrammableStage {
595 module,
596 entry_point: "main",
597 constants: &hashbrown::HashMap::from([
598 (
599 "supports_indirect_first_instance".to_string(),
600 f64::from(supports_indirect_first_instance),
601 ),
602 (
603 "write_d3d12_special_constants".to_string(),
604 f64::from(write_d3d12_special_constants),
605 ),
606 ]),
607 zero_initialize_workgroup_memory: false,
608 },
609 cache: None,
610 };
611 let pipeline =
612 unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
613 hal::PipelineError::Device(error) => {
614 CreateComputePipelineError::Device(DeviceError::from_hal(error))
615 }
616 hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
617 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
618 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
619 ),
620 hal::PipelineError::PipelineConstants(_, error) => {
621 CreateComputePipelineError::PipelineConstants(error)
622 }
623 })?;
624
625 Ok(pipeline)
626}
627
628fn create_bind_group_layout(
629 device: &dyn hal::DynDevice,
630 read_only: bool,
631 has_dynamic_offset: bool,
632 min_binding_size: wgt::BufferSize,
633 label: Option<&'static str>,
634) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
635 let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
636 label,
637 flags: hal::BindGroupLayoutFlags::empty(),
638 entries: &[wgt::BindGroupLayoutEntry {
639 binding: 0,
640 visibility: wgt::ShaderStages::COMPUTE,
641 ty: wgt::BindingType::Buffer {
642 ty: wgt::BufferBindingType::Storage { read_only },
643 has_dynamic_offset,
644 min_binding_size: Some(min_binding_size),
645 },
646 count: None,
647 }],
648 };
649 let bind_group_layout = unsafe {
650 device
651 .create_bind_group_layout(&bind_group_layout_desc)
652 .map_err(DeviceError::from_hal)?
653 };
654
655 Ok(bind_group_layout)
656}
657
658fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
660 let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size;
661 let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
662
663 if buffer_size <= max_storage_buffer_binding_size {
664 buffer_size
665 } else {
666 let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
667 let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;
668
669 if buffer_rem <= binding_rem {
672 max_storage_buffer_binding_size - binding_rem + buffer_rem
673 }
674 else {
676 max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
677 + buffer_rem
678 }
679 }
680}
681
682fn calculate_src_offsets(
684 buffer_size: u64,
685 limits: &Limits,
686 offset: u64,
687 data_size: u64,
688) -> (u64, u64) {
689 const MAX_DATA_SIZE: u64 = 20; let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
691 let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
692
693 assert!([16, MAX_DATA_SIZE].contains(&data_size));
694 assert!([32, 64, 128, 256].contains(&min_storage_buffer_offset_alignment));
695 assert!(buffer_size >= data_size);
696 assert!(offset <= buffer_size - data_size);
697 assert!(binding_size <= buffer_size);
698
699 let dynamic_offset_stride = binding_size.saturating_sub(MAX_DATA_SIZE)
717 / min_storage_buffer_offset_alignment
718 * min_storage_buffer_offset_alignment;
719 if dynamic_offset_stride == 0 {
720 return (0, offset);
721 }
722
723 let max_dynamic_offset = buffer_size - binding_size;
724 let out_dynamic_offset =
725 max_dynamic_offset.min(offset / dynamic_offset_stride * dynamic_offset_stride);
726 let out_offset = offset - out_dynamic_offset;
727
728 (out_dynamic_offset, out_offset)
729}
730
/// A pooled buffer together with the bind group that exposes it to the
/// validation shader.
///
/// Both objects are created together by `create_buffer_and_bind_group` and
/// destroyed together in `Draw::dispose`.
#[derive(Debug)]
struct BufferPoolEntry {
    /// Buffer of `BUFFER_SIZE` bytes.
    buffer: Box<dyn hal::DynBuffer>,
    /// Bind group binding `buffer` at binding 0 over its full size.
    bind_group: Box<dyn hal::DynBindGroup>,
}
736
737fn create_buffer_and_bind_group(
738 device: &dyn hal::DynDevice,
739 usage: wgt::BufferUses,
740 bind_group_layout: &dyn hal::DynBindGroupLayout,
741 buffer_label: Option<&'static str>,
742 bind_group_label: Option<&'static str>,
743) -> Result<BufferPoolEntry, hal::DeviceError> {
744 let buffer_desc = hal::BufferDescriptor {
745 label: buffer_label,
746 size: BUFFER_SIZE.get(),
747 usage,
748 memory_flags: hal::MemoryFlags::empty(),
749 };
750 let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
751 let bind_group_desc = hal::BindGroupDescriptor {
752 label: bind_group_label,
753 layout: bind_group_layout,
754 entries: &[hal::BindGroupEntry {
755 binding: 0,
756 resource_index: 0,
757 count: 1,
758 }],
759 buffers: &[hal::BufferBinding::new_unchecked(
761 buffer.as_ref(),
762 0,
763 BUFFER_SIZE,
764 )],
765 samplers: &[],
766 textures: &[],
767 acceleration_structures: &[],
768 external_textures: &[],
769 };
770 let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
771 Ok(BufferPoolEntry { buffer, bind_group })
772}
773
/// Position inside a list of pooled buffers used for sub-range allocation.
///
/// When stored as a cursor, `offset` is the first free byte of buffer
/// `index`; when returned from an allocation, `offset` is the start of the
/// granted range.
#[derive(Clone)]
struct CurrentEntry {
    /// Index of the pooled buffer.
    index: usize,
    /// Byte offset within that buffer.
    offset: u64,
}
779
/// Destination and metadata buffers borrowed from the device's pools.
///
/// Entries are acquired on demand while sub-ranges are handed out and are
/// returned to the pools when this value is dropped.
pub(crate) struct DrawResources {
    /// Device whose indirect validation state owns the pools.
    device: Arc<Device>,
    /// Destination entries acquired so far, addressed by resource index.
    dst_entries: Vec<BufferPoolEntry>,
    /// Metadata entries acquired so far, addressed by resource index.
    metadata_entries: Vec<BufferPoolEntry>,
}
786
787impl Drop for DrawResources {
788 fn drop(&mut self) {
789 if let Some(ref indirect_validation) = self.device.indirect_validation {
790 let indirect_draw_validation = &indirect_validation.draw;
791 indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
792 indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
793 }
794 }
795}
796
797impl DrawResources {
798 pub(crate) fn new(device: Arc<Device>) -> Self {
799 DrawResources {
800 device,
801 dst_entries: Vec::new(),
802 metadata_entries: Vec::new(),
803 }
804 }
805
806 pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
807 self.dst_entries.get(index).unwrap().buffer.as_ref()
808 }
809
810 fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
811 self.dst_entries.get(index).unwrap().bind_group.as_ref()
812 }
813
814 fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
815 self.metadata_entries.get(index).unwrap().buffer.as_ref()
816 }
817
818 fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
819 self.metadata_entries
820 .get(index)
821 .unwrap()
822 .bind_group
823 .as_ref()
824 }
825
826 fn get_dst_subrange(
827 &mut self,
828 size: u64,
829 current_entry: &mut Option<CurrentEntry>,
830 ) -> Result<(usize, u64), DeviceError> {
831 let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
832 let ensure_entry = |index: usize| {
833 if self.dst_entries.len() <= index {
834 let entry = indirect_draw_validation
835 .acquire_dst_entry(self.device.raw(), self.device.instance_flags)?;
836 self.dst_entries.push(entry);
837 }
838 Ok(())
839 };
840 let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
841 Ok((entry_data.index, entry_data.offset))
842 }
843
844 fn get_metadata_subrange(
845 &mut self,
846 size: u64,
847 current_entry: &mut Option<CurrentEntry>,
848 ) -> Result<(usize, u64), DeviceError> {
849 let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
850 let ensure_entry = |index: usize| {
851 if self.metadata_entries.len() <= index {
852 let entry = indirect_draw_validation
853 .acquire_metadata_entry(self.device.raw(), self.device.instance_flags)?;
854 self.metadata_entries.push(entry);
855 }
856 Ok(())
857 };
858 let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
859 Ok((entry_data.index, entry_data.offset))
860 }
861
862 fn get_subrange_impl(
863 ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
864 current_entry: &mut Option<CurrentEntry>,
865 size: u64,
866 ) -> Result<CurrentEntry, DeviceError> {
867 let index = if let Some(current_entry) = current_entry.as_mut() {
868 if current_entry.offset + size <= BUFFER_SIZE.get() {
869 let entry_data = current_entry.clone();
870 current_entry.offset += size;
871 return Ok(entry_data);
872 } else {
873 current_entry.index + 1
874 }
875 } else {
876 0
877 };
878
879 ensure_entry(index).map_err(DeviceError::from_hal)?;
880
881 let entry_data = CurrentEntry { index, offset: 0 };
882
883 *current_entry = Some(CurrentEntry {
884 index,
885 offset: size,
886 });
887
888 Ok(entry_data)
889 }
890}
891
892#[repr(C)]
894struct MetadataEntry {
895 src_offset: u32,
896 dst_offset: u32,
897 vertex_or_index_limit: u32,
898 instance_limit: u32,
899}
900
901impl MetadataEntry {
902 fn new(
903 indexed: bool,
904 src_offset: u64,
905 dst_offset: u64,
906 vertex_or_index_limit: u64,
907 instance_limit: u64,
908 ) -> Self {
909 const U32_MAX_AS_U64: u64 = u32::MAX as u64;
910
911 let src_offset = u64_offset_to_u32_offset(src_offset);
912 let src_offset = src_offset / 4; let src_offset = src_offset | ((indexed as u32) << 31);
917
918 let max_limit = U32_MAX_AS_U64 + U32_MAX_AS_U64; let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
922 let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; let vertex_or_index_limit = vertex_or_index_limit as u32; let instance_limit = instance_limit.min(max_limit);
926 let instance_limit_bit_32 = (instance_limit >> 32) as u32; let instance_limit = instance_limit as u32; let dst_offset = u64_offset_to_u32_offset(dst_offset);
930 let dst_offset = dst_offset / 4; let dst_offset =
936 dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
937
938 Self {
939 src_offset,
940 dst_offset,
941 vertex_or_index_limit,
942 instance_limit,
943 }
944 }
945}
946
947struct DrawIndirectValidationBatch {
948 src_buffer: Arc<crate::resource::Buffer>,
949 src_dynamic_offset: u64,
950 dst_resource_index: usize,
951 entries: Vec<MetadataEntry>,
952
953 staging_buffer_index: usize,
954 staging_buffer_offset: u64,
955 metadata_resource_index: usize,
956 metadata_buffer_offset: u64,
957}
958
959impl DrawIndirectValidationBatch {
960 fn metadata(&self) -> &[u8] {
962 unsafe {
963 core::slice::from_raw_parts(
964 self.entries.as_ptr().cast::<u8>(),
965 self.entries.len() * size_of::<MetadataEntry>(),
966 )
967 }
968 }
969}
970
971pub(crate) struct DrawBatcher {
973 batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
974 current_dst_entry: Option<CurrentEntry>,
975}
976
977impl DrawBatcher {
978 pub(crate) fn new() -> Self {
979 Self {
980 batches: FastHashMap::default(),
981 current_dst_entry: None,
982 }
983 }
984
985 pub(crate) fn add<'a>(
990 &mut self,
991 indirect_draw_validation_resources: &'a mut DrawResources,
992 device: &Device,
993 src_buffer: &Arc<crate::resource::Buffer>,
994 offset: u64,
995 family: crate::command::DrawCommandFamily,
996 vertex_or_index_limit: u64,
997 instance_limit: u64,
998 ) -> Result<(usize, u64), DeviceError> {
999 let stride = crate::command::get_dst_stride_of_indirect_args(device.backend(), family);
1000
1001 let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
1002 .get_dst_subrange(stride, &mut self.current_dst_entry)?;
1003
1004 let buffer_size = src_buffer.size;
1005 let limits = device.adapter.limits();
1006 let data_size = get_src_stride_of_indirect_args(family);
1007 let (src_dynamic_offset, src_offset) =
1008 calculate_src_offsets(buffer_size, &limits, offset, data_size);
1009
1010 let src_buffer_tracker_index = src_buffer.tracker_index();
1011
1012 let entry = MetadataEntry::new(
1013 family == crate::command::DrawCommandFamily::DrawIndexed,
1014 src_offset,
1015 dst_offset,
1016 vertex_or_index_limit,
1017 instance_limit,
1018 );
1019
1020 match self.batches.entry((
1021 src_buffer_tracker_index,
1022 src_dynamic_offset,
1023 dst_resource_index,
1024 )) {
1025 hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
1026 occupied_entry.get_mut().entries.push(entry)
1027 }
1028 hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
1029 vacant_entry.insert(DrawIndirectValidationBatch {
1030 src_buffer: src_buffer.clone(),
1031 src_dynamic_offset,
1032 dst_resource_index,
1033 entries: vec![entry],
1034
1035 staging_buffer_index: 0,
1037 staging_buffer_offset: 0,
1038 metadata_resource_index: 0,
1039 metadata_buffer_offset: 0,
1040 });
1041 }
1042 }
1043
1044 Ok((dst_resource_index, dst_offset))
1045 }
1046}
1047
/// Narrows a byte offset to `u32`.
///
/// Panics if `offset` is larger than `u32::MAX`.
fn u64_offset_to_u32_offset(offset: u64) -> u32 {
    u32::try_from(offset).unwrap()
}
1054
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `calculate_src_offsets` against a table of expected results and
    /// against the invariants its outputs must satisfy.
    #[test]
    fn calculate_src_offsets_test() {
        // Buffer size and binding size used by the large-buffer cases.
        const MBUS: u64 = 256 << 20;
        const MBIS: u64 = 128 << 20;

        // Columns: buffer size, max binding size, offset alignment, data size,
        // input offset, expected dynamic offset, expected offset in binding.
        #[rustfmt::skip]
        let cases: &[(u64, u64, u32, u64, u64, u64, u64)] = &[
            // Large buffers.
            (MBUS,     MBIS, 32, 16, 0,              0,    0),
            (MBUS,     MBIS, 32, 16, MBUS - 16,      MBIS, MBIS - 16),
            (MBUS + 4, MBUS, 32, 16, MBUS + 4 - 16,  32,   MBUS - 32 + 4 - 16),

            // 16-byte arguments around the first stride boundary.
            (512, 256,  32, 16, 240, 224,  16),
            (512, 256,  32, 16, 248, 224,  24),
            (512, 256,  32, 16, 256, 224,  32),
            (512, 256,  64, 16, 240, 192,  48),
            (512, 256,  64, 16, 248, 192,  56),
            (512, 256,  64, 16, 256, 192,  64),
            (512, 256, 128, 16, 240, 128, 112),
            (512, 256, 128, 16, 248, 128, 120),
            (512, 256, 128, 16, 256, 256,   0),

            // 20-byte arguments around the first stride boundary.
            (512, 256,  32, 20, 236, 224,  12),
            (512, 256,  32, 20, 244, 224,  20),
            (512, 256,  32, 20, 252, 224,  28),
            (512, 256,  64, 20, 236, 192,  44),
            (512, 256,  64, 20, 244, 192,  52),
            (512, 256,  64, 20, 252, 192,  60),
            (512, 256, 128, 20, 236, 128, 108),
            (512, 256, 128, 20, 244, 128, 116),
            (512, 256, 128, 20, 252, 128, 124),
        ];

        for case in cases {
            let &(
                buffer_size,
                max_storage_buffer_binding_size,
                min_storage_buffer_offset_alignment,
                data_size,
                offset,
                expected_out_dynamic_offset,
                expected_out_offset,
            ) = case;

            let limits = Limits {
                max_storage_buffer_binding_size,
                min_storage_buffer_offset_alignment,
                ..Limits::default()
            };
            let (out_dynamic_offset, out_offset) =
                calculate_src_offsets(buffer_size, &limits, offset, data_size);
            let binding_size = calculate_src_buffer_binding_size(buffer_size, &limits);

            // The two parts must add up to the input offset.
            assert_eq!(out_dynamic_offset + out_offset, offset);
            // The dynamic offset must respect the alignment.
            assert_eq!(
                out_dynamic_offset % min_storage_buffer_offset_alignment as u64,
                0
            );
            // The bound range must stay inside the buffer.
            assert!(out_dynamic_offset + binding_size <= buffer_size);
            // The arguments must stay inside the bound range.
            assert!(out_offset + data_size <= binding_size);
            // Finally, compare against the expected values.
            assert_eq!(
                (out_dynamic_offset, out_offset),
                (expected_out_dynamic_offset, expected_out_offset),
                "buffer_size={buffer_size} \
                 max_binding_size={max_storage_buffer_binding_size} \
                 offset_alignment={min_storage_buffer_offset_alignment} \
                 data_size={data_size} \
                 offset={offset}"
            );
        }
    }
}