1use super::{
2 utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3 CreateIndirectValidationPipelineError,
4};
5use crate::{
6 command::RenderPassErrorInner,
7 device::{queue::TempResource, Device, DeviceError},
8 hal_label,
9 lock::{rank, Mutex},
10 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
11 resource::{RawResourceAccess as _, StagingBuffer, Trackable},
12 snatch::SnatchGuard,
13 track::TrackerIndex,
14 FastHashMap,
15};
16use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
17use core::{
18 mem::{size_of, size_of_val},
19 num::NonZeroU64,
20};
21use wgt::Limits;
22
/// Size of each pooled metadata/destination buffer: 1_048_560 = 2^20 - 16,
/// i.e. just under 1 MiB and a multiple of 16 (one `MetadataEntry` is 16
/// bytes). NOTE(review): presumably kept below 1 MiB to satisfy common
/// storage-binding limits — confirm before changing.
const BUFFER_SIZE: wgt::BufferSize = wgt::BufferSize::new(1_048_560).unwrap();
38
/// State for validating indirect draw commands.
///
/// Owns the compute pipeline that validates indirect draw arguments, the
/// three bind group layouts it binds (metadata, source, destination), and
/// pools of reusable buffer/bind-group pairs.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    // Free lists of pooled buffer + bind group pairs, refilled by
    // `release_dst_entries` / `release_metadata_entries`.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
62
63impl Draw {
    /// Creates the draw-validation pipeline and its supporting objects.
    ///
    /// Builds the WGSL validation shader module, the three bind group
    /// layouts, the pipeline layout, and the compute pipeline. Both buffer
    /// pools start empty.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        instance_flags: wgt::InstanceFlags,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device, instance_flags)?;

        // Group 0: read-only metadata, bound in `BUFFER_SIZE` chunks.
        let metadata_bind_group_layout = create_bind_group_layout(
            device,
            true,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 1: read-only source indirect args with a dynamic offset;
        // min binding size covers 4 * 4 bytes of draw arguments.
        let src_bind_group_layout = create_bind_group_layout(
            device,
            true,
            true,
            wgt::BufferSize::new(4 * 4).unwrap(),
            hal_label(
                Some("(wgpu internal) Indirect draw validation source bind group layout"),
                instance_flags,
            ),
        )?;
        // Group 2: read-write destination buffer receiving validated args.
        let dst_bind_group_layout = create_bind_group_layout(
            device,
            false,
            false,
            BUFFER_SIZE,
            hal_label(
                Some("(wgpu internal) Indirect draw validation destination bind group layout"),
                instance_flags,
            ),
        )?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pipeline layout"),
                instance_flags,
            ),
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                Some(metadata_bind_group_layout.as_ref()),
                Some(src_bind_group_layout.as_ref()),
                Some(dst_bind_group_layout.as_ref()),
            ],
            // Two u32 immediates: metadata start index and entry count (see
            // `set_immediates` in `inject_validation_pass`).
            immediate_size: 8,
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        // Specialize the shader: may first_instance be non-zero, and must
        // DX12 special constants be written ahead of the arguments?
        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
            instance_flags,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }
146
    /// Creates the bind group exposing a user buffer as the validation
    /// pass's source binding (group 1).
    ///
    /// Returns `Ok(None)` when the computed binding size is zero, in which
    /// case no bind group can (or needs to) be created.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
        instance_flags: wgt::InstanceFlags,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        // Clamp the binding to the device's storage-binding limit while
        // keeping it congruent with the buffer size modulo the offset
        // alignment (see `calculate_src_buffer_binding_size`).
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation source bind group"),
                instance_flags,
            ),
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // The layout declares a dynamic offset; the binding itself
            // always starts at 0.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }
185
186 fn acquire_dst_entry(
187 &self,
188 device: &dyn hal::DynDevice,
189 instance_flags: wgt::InstanceFlags,
190 ) -> Result<BufferPoolEntry, hal::DeviceError> {
191 let mut free_buffers = self.free_indirect_entries.lock();
192 match free_buffers.pop() {
193 Some(buffer) => Ok(buffer),
194 None => {
195 let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
196 create_buffer_and_bind_group(
197 device,
198 usage,
199 self.dst_bind_group_layout.as_ref(),
200 hal_label(Some("(wgpu internal) Indirect draw validation destination buffer"), instance_flags),
201 hal_label(Some("(wgpu internal) Indirect draw validation destination bind group layout"), instance_flags),
202 )
203 }
204 }
205 }
206
207 fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
208 self.free_indirect_entries.lock().extend(entries);
209 }
210
211 fn acquire_metadata_entry(
212 &self,
213 device: &dyn hal::DynDevice,
214 instance_flags: wgt::InstanceFlags,
215 ) -> Result<BufferPoolEntry, hal::DeviceError> {
216 let mut free_buffers = self.free_metadata_entries.lock();
217 match free_buffers.pop() {
218 Some(buffer) => Ok(buffer),
219 None => {
220 let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
221 create_buffer_and_bind_group(
222 device,
223 usage,
224 self.metadata_bind_group_layout.as_ref(),
225 hal_label(
226 Some("(wgpu internal) Indirect draw validation metadata buffer"),
227 instance_flags,
228 ),
229 hal_label(
230 Some("(wgpu internal) Indirect draw validation metadata bind group layout"),
231 instance_flags,
232 ),
233 )
234 }
235 }
236 }
237
238 fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
239 self.free_metadata_entries.lock().extend(entries);
240 }
241
    /// Encodes the validation work accumulated in `batcher` into `encoder`.
    ///
    /// Stages:
    /// 1. Assign each batch a slot in one of several CPU staging buffers and
    ///    write its metadata there.
    /// 2. Barrier, then copy metadata from staging into pooled metadata
    ///    buffers.
    /// 3. Barrier, then run one compute dispatch per batch, writing
    ///    validated indirect arguments into pooled destination buffers.
    /// 4. Final barrier so destination buffers can be consumed as INDIRECT.
    ///
    /// Staging buffers are pushed onto `temp_resources` so they outlive the
    /// submission.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        // Cap each staging allocation at 64 MiB.
        let max_staging_buffer_size = 1 << 26;
        let mut staging_buffers = Vec::new();

        // First pass: assign every batch a (staging buffer, offset) slot,
        // creating a new staging buffer whenever the current one would
        // overflow. Buffers are only pushed once full (or at the very end),
        // so `staging_buffers.len()` is always the index of the buffer
        // currently being filled.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        // Create the final, partially filled staging buffer.
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Second pass: write each batch's metadata bytes into its slot.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        // Flush all staging buffers before the GPU reads them.
        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Assign each batch a subrange of a pooled metadata buffer.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Transition staging buffers to COPY_SRC and metadata buffers to
        // COPY_DST for the upcoming copies. `unique()` deduplicates so each
        // buffer gets exactly one barrier.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy each batch's metadata from its staging slot into its
        // metadata-buffer subrange.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // Keep the staging buffers alive until the submission completes.
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Transition metadata buffers back to shader-readable and the
        // destination buffers to writable for the compute pass.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect draw validation pass"),
                device.instance_flags,
            ),
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        // One dispatch per batch.
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Immediates: where this batch's entries start within the
            // metadata buffer (in entries, not bytes) and how many there are.
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, metadata_bind_group, &[]);
            }

            // Errors out if the source buffer has been destroyed.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    src_bind_group,
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, dst_bind_group, &[]);
            }

            // One entry per invocation, 64 invocations per workgroup
            // (must match the shader's workgroup size).
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Transition destination buffers for consumption as indirect
        // argument sources by the subsequent render pass.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }
479
    /// Destroys every HAL object owned by this validator.
    ///
    /// Consumes `self`. `device` must be the device the resources were
    /// created on.
    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        // Full destructure: adding a field without handling it here becomes
        // a compile error.
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        // Drain both pools; destroy each bind group before the buffer it
        // references.
        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        // Destroy the pipeline before the layout and layouts/module it was
        // created from.
        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
516}
517
/// Builds the draw-validation shader module from `validate_draw.wgsl`.
///
/// Parses the embedded WGSL, runs naga validation, then creates the HAL
/// shader module with runtime checks disabled (the shader is trusted,
/// internally generated code).
fn create_validation_module(
    device: &dyn hal::DynDevice,
    instance_flags: wgt::InstanceFlags,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    // The naga WGSL frontend is only available with the `wgsl` feature;
    // without it, reaching this function is a hard error.
    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    // Validate with only the features the shader needs (immediates).
    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: hal_label(
            Some("(wgpu internal) Indirect draw validation shader module"),
            instance_flags,
        ),
        // Internal shader: skip the runtime bounds checks wgpu would add to
        // user shaders.
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}
575
/// Creates the validation compute pipeline from the prebuilt module and
/// pipeline layout.
///
/// The two booleans are passed to the shader as pipeline-override constants
/// (encoded as `f64` values, per the HAL constants interface).
fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
    instance_flags: wgt::InstanceFlags,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: hal_label(
            Some("(wgpu internal) Indirect draw validation pipeline"),
            instance_flags,
        ),
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    // Map HAL pipeline errors onto the crate's compute-pipeline error type.
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}
623
624fn create_bind_group_layout(
625 device: &dyn hal::DynDevice,
626 read_only: bool,
627 has_dynamic_offset: bool,
628 min_binding_size: wgt::BufferSize,
629 label: Option<&'static str>,
630) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
631 let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
632 label,
633 flags: hal::BindGroupLayoutFlags::empty(),
634 entries: &[wgt::BindGroupLayoutEntry {
635 binding: 0,
636 visibility: wgt::ShaderStages::COMPUTE,
637 ty: wgt::BindingType::Buffer {
638 ty: wgt::BufferBindingType::Storage { read_only },
639 has_dynamic_offset,
640 min_binding_size: Some(min_binding_size),
641 },
642 count: None,
643 }],
644 };
645 let bind_group_layout = unsafe {
646 device
647 .create_bind_group_layout(&bind_group_layout_desc)
648 .map_err(DeviceError::from_hal)?
649 };
650
651 Ok(bind_group_layout)
652}
653
654fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
656 let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
657 let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
658
659 if buffer_size <= max_storage_buffer_binding_size {
660 buffer_size
661 } else {
662 let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
663 let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;
664
665 if buffer_rem <= binding_rem {
668 max_storage_buffer_binding_size - binding_rem + buffer_rem
669 }
670 else {
672 max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
673 + buffer_rem
674 }
675 }
676}
677
/// Splits `offset` into a dynamic bind-group offset plus a residual offset
/// inside the binding, for a source buffer bound with the size from
/// `calculate_src_buffer_binding_size`.
///
/// Returns `(src_dynamic_offset, src_offset)` in bytes, with
/// `src_dynamic_offset + src_offset == offset`.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // The dynamic-offset stride is shrunk by a few alignment chunks so a
    // draw-args read starting near the end of one stride still lands inside
    // the binding. NOTE(review): these per-alignment values appear tuned to
    // the maximum argument size read by validate_draw.wgsl — confirm against
    // the shader before changing. Alignments other than 4/8/16+ are assumed
    // impossible here.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        4 => 0,
        8 => 2,
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    // Tiny bindings can yield a zero stride; fall back to no dynamic offset.
    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    // Clamp the dynamic offset so the binding never extends past the end of
    // the buffer.
    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
717
/// A pooled buffer paired with the bind group that binds its full range.
#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}
723
/// Creates a `BUFFER_SIZE`-byte buffer with the given usage and a bind group
/// binding its entire range at `@binding(0)` against `bind_group_layout`.
fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
    buffer_label: Option<&'static str>,
    bind_group_label: Option<&'static str>,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: buffer_label,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: bind_group_label,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        // Bind the full BUFFER_SIZE range starting at offset 0.
        buffers: &[hal::BufferBinding::new_unchecked(
            buffer.as_ref(),
            0,
            BUFFER_SIZE,
        )],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
        external_textures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}
760
/// Bump-allocation cursor into a buffer pool: which pool entry (`index`)
/// and how many bytes of it are already allocated (`offset`).
#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}
766
/// Pooled buffers checked out from the device's draw validator for one
/// encoding session; returned to the pools when dropped.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    // Destination (validated indirect args) entries, addressed by the
    // "resource index" handed out by `get_dst_subrange`.
    dst_entries: Vec<BufferPoolEntry>,
    // Metadata entries, addressed by the index from `get_metadata_subrange`.
    metadata_entries: Vec<BufferPoolEntry>,
}
773
774impl Drop for DrawResources {
775 fn drop(&mut self) {
776 if let Some(ref indirect_validation) = self.device.indirect_validation {
777 let indirect_draw_validation = &indirect_validation.draw;
778 indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
779 indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
780 }
781 }
782}
783
784impl DrawResources {
785 pub(crate) fn new(device: Arc<Device>) -> Self {
786 DrawResources {
787 device,
788 dst_entries: Vec::new(),
789 metadata_entries: Vec::new(),
790 }
791 }
792
793 pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
794 self.dst_entries.get(index).unwrap().buffer.as_ref()
795 }
796
797 fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
798 self.dst_entries.get(index).unwrap().bind_group.as_ref()
799 }
800
801 fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
802 self.metadata_entries.get(index).unwrap().buffer.as_ref()
803 }
804
805 fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
806 self.metadata_entries
807 .get(index)
808 .unwrap()
809 .bind_group
810 .as_ref()
811 }
812
    /// Reserves `size` bytes in a destination pool buffer, acquiring a new
    /// pool entry from the device's validator on demand.
    ///
    /// Returns `(resource index, byte offset)` into `self.dst_entries`.
    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Lazily ensure the pool entry at `index` exists.
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation
                    .acquire_dst_entry(self.device.raw(), self.device.instance_flags)?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
830
    /// Reserves `size` bytes in a metadata pool buffer, acquiring a new pool
    /// entry from the device's validator on demand.
    ///
    /// Returns `(resource index, byte offset)` into `self.metadata_entries`.
    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Lazily ensure the pool entry at `index` exists.
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation
                    .acquire_metadata_entry(self.device.raw(), self.device.instance_flags)?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
848
    /// Bump-allocates `size` bytes from the buffer `current_entry` points
    /// at, moving to the next pool entry when the current one (capacity
    /// `BUFFER_SIZE`) cannot fit the request.
    ///
    /// Returns the entry index and start offset of the reservation;
    /// `current_entry` is advanced past it.
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                // Fits in the current entry: hand out the current position
                // and advance the cursor.
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                // Doesn't fit: start on the next entry.
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        // The new reservation occupies [0, size) of the fresh entry.
        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
877}
878
/// Per-draw metadata consumed by the validation shader; `#[repr(C)]` so the
/// in-memory layout matches the shader-side struct byte for byte.
#[repr(C)]
struct MetadataEntry {
    // Source offset in 4-byte words; bit 31 flags an indexed draw.
    src_offset: u32,
    // Destination offset in 4-byte words; bits 30/31 carry bit 32 of the
    // vertex/index and instance limits respectively.
    dst_offset: u32,
    // Low 32 bits of the vertex-or-index count limit.
    vertex_or_index_limit: u32,
    // Low 32 bits of the instance count limit.
    instance_limit: u32,
}
887
888impl MetadataEntry {
889 fn new(
890 indexed: bool,
891 src_offset: u64,
892 dst_offset: u64,
893 vertex_or_index_limit: u64,
894 instance_limit: u64,
895 ) -> Self {
896 debug_assert_eq!(
897 4,
898 size_of_val(&Limits::default().max_storage_buffer_binding_size)
899 );
900
901 let src_offset = src_offset as u32; let src_offset = src_offset / 4; let src_offset = src_offset | ((indexed as u32) << 31);
907
908 let max_limit = u32::MAX as u64 + u32::MAX as u64; let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
912 let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; let vertex_or_index_limit = vertex_or_index_limit as u32; let instance_limit = instance_limit.min(max_limit);
916 let instance_limit_bit_32 = (instance_limit >> 32) as u32; let instance_limit = instance_limit as u32; let dst_offset = dst_offset as u32; let dst_offset = dst_offset / 4; let dst_offset =
926 dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
927
928 Self {
929 src_offset,
930 dst_offset,
931 vertex_or_index_limit,
932 instance_limit,
933 }
934 }
935}
936
/// One batch of indirect draws that share a source buffer, source dynamic
/// offset, and destination pool entry, validated by a single dispatch.
struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    // The fields below are filled in by `inject_validation_pass` once
    // staging and metadata-buffer space has been assigned.
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}
948
impl DrawIndirectValidationBatch {
    /// Returns the batch's `MetadataEntry` list as raw bytes, ready to be
    /// written into a staging buffer.
    fn metadata(&self) -> &[u8] {
        // SAFETY: `entries` owns `len * size_of::<MetadataEntry>()`
        // initialized bytes, and `MetadataEntry` is `#[repr(C)]` containing
        // only `u32` fields, so viewing them as bytes is sound. The
        // returned slice borrows `self`, so it cannot outlive the Vec.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}
960
/// Accumulates indirect draws into validation batches, keyed by
/// `(source buffer tracker index, source dynamic offset, destination
/// resource index)` so each batch can be handled by one dispatch.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    // Allocation cursor for destination-buffer subranges.
    current_dst_entry: Option<CurrentEntry>,
}
966
967impl DrawBatcher {
968 pub(crate) fn new() -> Self {
969 Self {
970 batches: FastHashMap::default(),
971 current_dst_entry: None,
972 }
973 }
974
    /// Queues one indirect draw for validation.
    ///
    /// Reserves destination space for the validated arguments and records a
    /// `MetadataEntry` in the batch keyed by the source buffer's tracker
    /// index, the computed source dynamic offset, and the destination
    /// resource index.
    ///
    /// Returns `(dst_resource_index, dst_offset)`: where the validated
    /// arguments will be written for the draw call to consume.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // On DX12 the validated arguments are prefixed with three special
        // constant u32s (see `write_d3d12_special_constants`).
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        // Reserve space for this draw's validated arguments.
        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        // Split the user's byte offset into a bind-group dynamic offset and
        // a residual offset inside the binding.
        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        // Append to the matching batch, or open a new one. The remaining
        // fields are filled in later by `inject_validation_pass`.
        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
1040}