1use super::{
2 utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3 CreateIndirectValidationPipelineError,
4};
5use crate::{
6 command::RenderPassErrorInner,
7 device::{queue::TempResource, Device, DeviceError},
8 hal_label,
9 lock::{rank, Mutex},
10 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
11 resource::{RawResourceAccess as _, StagingBuffer, Trackable},
12 snatch::SnatchGuard,
13 track::TrackerIndex,
14 FastHashMap,
15};
16use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
17use core::{mem::size_of, num::NonZeroU64};
18use wgt::Limits;
19
/// Size of each pooled metadata/destination buffer.
/// 1_048_560 = 2^20 - 16: just under 1 MiB and divisible by 16
/// (`size_of::<MetadataEntry>()`), so whole metadata entries always fit.
const BUFFER_SIZE: wgt::BufferSize = wgt::BufferSize::new(1_048_560).unwrap();
35
/// Pipeline state and buffer pools used to validate indirect draw arguments
/// with an internal compute pass before the GPU consumes them.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    // Bind group layouts for the three bindings of the validation pipeline:
    // read-only metadata, read-only (dynamic-offset) source args, and
    // read-write destination args.
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    // Free lists of pooled destination / metadata buffers, reused across
    // validation passes (see `acquire_*_entry` / `release_*_entries`).
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
59
impl Draw {
    /// Creates all GPU objects needed for indirect draw validation: the
    /// validation shader module, the three bind group layouts (metadata,
    /// source, destination), the pipeline layout, and the compute pipeline.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        instance_flags: wgt::InstanceFlags,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device, instance_flags)?;

        // Group 0: read-only storage holding packed `MetadataEntry` records.
        let metadata_bind_group_layout = create_bind_group_layout(
            device,
            true,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 1: read-only storage over the user's indirect buffer, bound
        // with a dynamic offset. Minimum binding size is one set of indirect
        // args (4 u32s = 16 bytes).
        let src_bind_group_layout = create_bind_group_layout(
            device,
            true,
            true,
            wgt::BufferSize::new(4 * 4).unwrap(),
            hal_label(
                Some("(wgpu internal) Indirect draw validation source bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 2: read-write storage receiving the validated indirect args.
        let dst_bind_group_layout = create_bind_group_layout(
            device,
            false,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation destination bind group layout"),
                instance_flags,
            ),
        )?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pipeline layout"),
                instance_flags,
            ),
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                Some(metadata_bind_group_layout.as_ref()),
                Some(src_bind_group_layout.as_ref()),
                Some(dst_bind_group_layout.as_ref()),
            ],
            // Two u32 immediates: first metadata entry index and entry count
            // (set per batch in `inject_validation_pass`).
            immediate_size: 8,
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        // DX12 writes 3 extra u32 "special constants" ahead of each set of
        // validated args (see `DrawBatcher::add`).
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
            instance_flags,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

    /// Creates the bind group exposing `buffer` as the validation pass's
    /// source binding (group 1).
    ///
    /// Returns `Ok(None)` when the computed binding size is zero (nothing
    /// can be bound).
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
        instance_flags: wgt::InstanceFlags,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation source bind group"),
                instance_flags,
            ),
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // Binding size was derived from `buffer_size` above, so the
            // unchecked constructor stays in-bounds.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    /// Takes a destination (validated indirect args) buffer + bind group from
    /// the free pool, or creates a fresh one when the pool is empty.
    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
        instance_flags: wgt::InstanceFlags,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                // Destination buffers are read as INDIRECT by draws and
                // written as STORAGE by the validation pass.
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.dst_bind_group_layout.as_ref(),
                    hal_label(Some("(wgpu internal) Indirect draw validation destination buffer"), instance_flags),
                    hal_label(Some("(wgpu internal) Indirect draw validation destination bind group layout"), instance_flags),
                )
            }
        }
    }

    /// Returns destination entries to the free pool for reuse.
    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    /// Takes a metadata buffer + bind group from the free pool, or creates a
    /// fresh one when the pool is empty.
    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
        instance_flags: wgt::InstanceFlags,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                // Metadata buffers are filled via copies from staging and
                // read by the validation shader.
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                    hal_label(
                        Some("(wgpu internal) Indirect draw validation metadata buffer"),
                        instance_flags,
                    ),
                    hal_label(
                        Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
                        instance_flags,
                    ),
                )
            }
        }
    }

    /// Returns metadata entries to the free pool for reuse.
    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

    /// Encodes the validation work for every batch collected by `batcher`:
    ///
    /// 1. packs each batch's `MetadataEntry` bytes into staging buffers,
    /// 2. copies the staging data into pooled metadata buffers,
    /// 3. runs one compute dispatch per batch that reads the source indirect
    ///    args and writes validated args into the destination buffers,
    ///
    /// inserting buffer barriers between each phase so usages line up.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        // Nothing to validate; encode nothing.
        if batches.is_empty() {
            return Ok(());
        }

        // Cap each staging buffer at 64 MiB.
        let max_staging_buffer_size = 1 << 26;
        let mut staging_buffers = Vec::new();

        // First pass: assign each batch a (staging buffer index, offset)
        // slot, starting a new staging buffer whenever the current one would
        // overflow. Buffers are created lazily once their final size is known.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            // `staging_buffers.len()` is the index of the buffer this batch
            // will land in once it is pushed below.
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        // Create the final (possibly only) staging buffer.
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Second pass: write each batch's metadata bytes into its slot.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            // SAFETY: the slot [offset, offset + data.len()) was sized for
            // exactly this batch's data in the first pass above.
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Reserve a subrange of a pooled metadata buffer for each batch.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Transition staging buffers to COPY_SRC and metadata buffers to
        // COPY_DST ahead of the copies (deduplicating per-buffer barriers).
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy each batch's metadata from staging into its metadata buffer.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // Staging buffers are no longer needed once the copies execute; hand
        // them off for deferred destruction.
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Transition metadata buffers back to shader-readable and destination
        // buffers to writable ahead of the compute pass.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pass"),
                device.instance_flags,
            ),
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        // One dispatch per batch.
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Immediates: where this batch's entries start (in units of
            // `MetadataEntry`) and how many there are.
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, metadata_bind_group, &[]);
            }

            // Fail early if the source buffer was destroyed.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    src_bind_group,
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, dst_bind_group, &[]);
            }

            // One workgroup per 64 entries — presumably matches the shader's
            // workgroup size; confirm in validate_draw.wgsl.
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Transition destination buffers back to INDIRECT so the validated
        // args can feed the subsequent draws.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    /// Destroys every GPU object owned by this `Draw`, including all pooled
    /// buffers still on the free lists. Consumes `self`.
    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        // Destructure so a newly added field can't be silently leaked.
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}
514
/// Parses, validates, and uploads the bundled `validate_draw.wgsl` shader,
/// returning the backend shader module.
fn create_validation_module(
    device: &dyn hal::DynDevice,
    instance_flags: wgt::InstanceFlags,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    // Without the wgsl frontend there is no way to build the internal shader.
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    // Validate against only the features the internal shader needs.
    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: hal_label(
            Some("(wgpu internal) Indirect draw validation shader module"),
            instance_flags,
        ),
        // Internally-authored shader: runtime bounds checks are skipped.
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}
572
573fn create_validation_pipeline(
574 device: &dyn hal::DynDevice,
575 module: &dyn hal::DynShaderModule,
576 pipeline_layout: &dyn hal::DynPipelineLayout,
577 supports_indirect_first_instance: bool,
578 write_d3d12_special_constants: bool,
579 instance_flags: wgt::InstanceFlags,
580) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
581 let pipeline_desc = hal::ComputePipelineDescriptor {
582 label: hal_label(
583 Some("(wgpu internal) Indirect draw validation pipeline"),
584 instance_flags,
585 ),
586 layout: pipeline_layout,
587 stage: hal::ProgrammableStage {
588 module,
589 entry_point: "main",
590 constants: &hashbrown::HashMap::from([
591 (
592 "supports_indirect_first_instance".to_string(),
593 f64::from(supports_indirect_first_instance),
594 ),
595 (
596 "write_d3d12_special_constants".to_string(),
597 f64::from(write_d3d12_special_constants),
598 ),
599 ]),
600 zero_initialize_workgroup_memory: false,
601 },
602 cache: None,
603 };
604 let pipeline =
605 unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
606 hal::PipelineError::Device(error) => {
607 CreateComputePipelineError::Device(DeviceError::from_hal(error))
608 }
609 hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
610 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
611 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
612 ),
613 hal::PipelineError::PipelineConstants(_, error) => {
614 CreateComputePipelineError::PipelineConstants(error)
615 }
616 })?;
617
618 Ok(pipeline)
619}
620
621fn create_bind_group_layout(
622 device: &dyn hal::DynDevice,
623 read_only: bool,
624 has_dynamic_offset: bool,
625 min_binding_size: wgt::BufferSize,
626 label: Option<&'static str>,
627) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
628 let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
629 label,
630 flags: hal::BindGroupLayoutFlags::empty(),
631 entries: &[wgt::BindGroupLayoutEntry {
632 binding: 0,
633 visibility: wgt::ShaderStages::COMPUTE,
634 ty: wgt::BindingType::Buffer {
635 ty: wgt::BufferBindingType::Storage { read_only },
636 has_dynamic_offset,
637 min_binding_size: Some(min_binding_size),
638 },
639 count: None,
640 }],
641 };
642 let bind_group_layout = unsafe {
643 device
644 .create_bind_group_layout(&bind_group_layout_desc)
645 .map_err(DeviceError::from_hal)?
646 };
647
648 Ok(bind_group_layout)
649}
650
/// Computes how many bytes of a source buffer of `buffer_size` bytes the
/// validation pass's source binding should cover.
///
/// If the whole buffer fits under `max_storage_buffer_binding_size`, bind all
/// of it. Otherwise pick the largest size not exceeding the limit that is
/// congruent to `buffer_size` modulo `min_storage_buffer_offset_alignment`:
/// keeping the remainders equal means `buffer_size - binding_size` is a
/// multiple of the alignment, so an aligned dynamic offset can place the
/// binding flush against the end of the buffer and the tail stays reachable.
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        // Shrink the maximum binding size so its remainder matches the
        // buffer's remainder.
        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        }
        // Matching the remainder directly would exceed the limit; drop one
        // additional alignment unit to stay within it.
        else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}
674
/// Splits `offset` (byte offset of the indirect args within the source
/// buffer) into `(src_dynamic_offset, src_offset)`: an alignment-multiple
/// dynamic offset applied when binding the source bind group, plus the
/// remainder relative to the start of the binding.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // The effective stride between usable dynamic offsets is the binding
    // size minus a few alignment "chunks", so that a full set of indirect
    // args starting near the end of one stride still lies entirely inside
    // the binding. NOTE(review): the per-alignment adjustment below appears
    // derived from the max args size (20 bytes for indexed draws) — confirm
    // against the original derivation.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        4 => 0,
        8 => 2,
        16.. => 1,
        // Presumably the alignment limit is always 4, 8, or >= 16 here —
        // other values are treated as impossible.
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    // Degenerate binding (too small for even one stride): bind at offset 0
    // and pass the whole offset through.
    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    // Clamp the dynamic offset so the binding never extends past the buffer.
    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
714
/// A pooled buffer together with the bind group that exposes it to the
/// validation pipeline.
#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}
720
721fn create_buffer_and_bind_group(
722 device: &dyn hal::DynDevice,
723 usage: wgt::BufferUses,
724 bind_group_layout: &dyn hal::DynBindGroupLayout,
725 buffer_label: Option<&'static str>,
726 bind_group_label: Option<&'static str>,
727) -> Result<BufferPoolEntry, hal::DeviceError> {
728 let buffer_desc = hal::BufferDescriptor {
729 label: buffer_label,
730 size: BUFFER_SIZE.get(),
731 usage,
732 memory_flags: hal::MemoryFlags::empty(),
733 };
734 let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
735 let bind_group_desc = hal::BindGroupDescriptor {
736 label: bind_group_label,
737 layout: bind_group_layout,
738 entries: &[hal::BindGroupEntry {
739 binding: 0,
740 resource_index: 0,
741 count: 1,
742 }],
743 buffers: &[hal::BufferBinding::new_unchecked(
745 buffer.as_ref(),
746 0,
747 BUFFER_SIZE,
748 )],
749 samplers: &[],
750 textures: &[],
751 acceleration_structures: &[],
752 external_textures: &[],
753 };
754 let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
755 Ok(BufferPoolEntry { buffer, bind_group })
756}
757
/// Cursor for bump-allocating subranges out of a pooled buffer:
/// `index` selects the pool entry, `offset` is the next free byte within it.
#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}
763
/// Pooled buffers backing one command encoder's draw-validation work.
/// On drop, the entries are returned to the owning device's free lists.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    // Buffers that receive validated indirect args (INDIRECT usage).
    dst_entries: Vec<BufferPoolEntry>,
    // Buffers holding `MetadataEntry` records read by the validation shader.
    metadata_entries: Vec<BufferPoolEntry>,
}
770
771impl Drop for DrawResources {
772 fn drop(&mut self) {
773 if let Some(ref indirect_validation) = self.device.indirect_validation {
774 let indirect_draw_validation = &indirect_validation.draw;
775 indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
776 indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
777 }
778 }
779}
780
impl DrawResources {
    /// Creates an empty resource set for `device`.
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    /// Returns the destination (validated indirect args) buffer at `index`.
    /// Panics if no entry with that index has been acquired yet.
    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    /// Returns the bind group over the destination buffer at `index`.
    /// Panics if no entry with that index has been acquired yet.
    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    /// Returns the metadata buffer at `index`.
    /// Panics if no entry with that index has been acquired yet.
    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    /// Returns the bind group over the metadata buffer at `index`.
    /// Panics if no entry with that index has been acquired yet.
    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    /// Reserves `size` bytes in a destination buffer, acquiring a new pooled
    /// buffer when the current one cannot fit the request.
    /// Returns `(entry index, byte offset)`.
    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Lazily acquires the pool entry at `index` if we don't have it yet.
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation
                    .acquire_dst_entry(self.device.raw(), self.device.instance_flags)?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    /// Reserves `size` bytes in a metadata buffer, acquiring a new pooled
    /// buffer when the current one cannot fit the request.
    /// Returns `(entry index, byte offset)`.
    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Lazily acquires the pool entry at `index` if we don't have it yet.
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation
                    .acquire_metadata_entry(self.device.raw(), self.device.instance_flags)?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    /// Shared bump-allocation logic for `get_dst_subrange` /
    /// `get_metadata_subrange`: carves `size` bytes out of the buffer tracked
    /// by `current_entry`, advancing to the next entry index whenever the
    /// current buffer (of `BUFFER_SIZE` bytes) cannot fit the request.
    /// `ensure_entry` guarantees the pool entry at the given index exists.
    /// Returns the entry index and the offset at which the range starts.
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                // Fits in the current buffer: hand out the current cursor
                // position and bump the offset.
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                // Current buffer is full; move on to the next one.
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        // The new buffer's cursor starts just past the range we return.
        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}
875
/// One packed draw-validation record, copied as raw bytes into a metadata
/// buffer and read by the validation shader — presumably mirrors a struct in
/// `validate_draw.wgsl`; confirm there. See `MetadataEntry::new` for the
/// exact bit packing.
#[repr(C)]
struct MetadataEntry {
    // Source offset in u32 words; bit 31 flags an indexed draw.
    src_offset: u32,
    // Destination offset in u32 words; bits 30/31 carry bit 32 of the
    // vertex/index limit and the instance limit respectively.
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}
884
impl MetadataEntry {
    /// Packs one draw's validation parameters into a 16-byte entry.
    ///
    /// Offsets are stored in u32 words (they are 4-byte aligned), freeing the
    /// top bits: bit 31 of `src_offset` carries the `indexed` flag, while bits
    /// 30/31 of `dst_offset` carry bit 32 of the two limits, allowing limits
    /// up to 33 bits.
    ///
    /// Panics if either offset exceeds `u32::MAX`.
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        const U32_MAX_AS_U64: u64 = u32::MAX as u64;

        assert!(src_offset <= U32_MAX_AS_U64);
        assert!(dst_offset <= U32_MAX_AS_U64);

        let src_offset = src_offset as u32;
        // Store in u32 words; the low 2 bits are always zero anyway.
        let src_offset = src_offset / 4;
        // Bit 31 flags an indexed draw.
        let src_offset = src_offset | ((indexed as u32) << 31);

        // Clamp limits to what the 33-bit encoding can represent; larger
        // limits are far beyond any valid draw anyway.
        let max_limit = U32_MAX_AS_U64 + U32_MAX_AS_U64;
        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        // Split each limit into its low 32 bits plus its 33rd bit.
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32;
        let vertex_or_index_limit = vertex_or_index_limit as u32;
        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32;
        let instance_limit = instance_limit as u32;
        let dst_offset = dst_offset as u32;
        // Store in u32 words, then stash the limits' 33rd bits in the two
        // top bits.
        let dst_offset = dst_offset / 4;
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}
934
/// A group of indirect draws sharing the same source buffer, source dynamic
/// offset, and destination buffer; validated by a single compute dispatch.
struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    // The fields below are assigned by `Draw::inject_validation_pass`.
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}
946
impl DrawIndirectValidationBatch {
    /// Returns this batch's `MetadataEntry` records as raw bytes, ready to be
    /// written into a staging buffer.
    fn metadata(&self) -> &[u8] {
        // SAFETY: `entries` is a live Vec allocation valid for reads of
        // `len * size_of::<MetadataEntry>()` bytes; `MetadataEntry` is
        // `#[repr(C)]` and made of four `u32`s, so it has no padding and
        // every byte is initialized.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}
958
/// Accumulates indirect draws into validation batches, keyed by
/// `(source buffer tracker index, source dynamic offset, destination entry
/// index)` so draws that can share bind groups are dispatched together.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    // Bump-allocation cursor into the destination buffer pool.
    current_dst_entry: Option<CurrentEntry>,
}
964
impl DrawBatcher {
    /// Creates an empty batcher.
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

    /// Registers one indirect draw for validation.
    ///
    /// Reserves space for the validated args in a destination buffer, packs
    /// the draw's parameters into a `MetadataEntry`, and appends it to the
    /// batch for this (source buffer, source dynamic offset, destination
    /// entry) combination.
    ///
    /// Returns `(destination entry index, byte offset)` — where the validated
    /// arguments will be written for the draw to consume.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // DX12 validated args are prefixed with 3 u32 "special constants"
        // (see `write_d3d12_special_constants` in `Draw::new`).
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        // Split the source offset into an aligned dynamic offset plus a
        // remainder relative to the binding start.
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        // Draws sharing source binding and destination buffer go into one
        // batch and are validated by a single dispatch.
        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // Assigned later by `Draw::inject_validation_pass`.
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}