1use super::{
2 utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3 CreateIndirectValidationPipelineError,
4};
5use crate::{
6 device::{queue::TempResource, Device, DeviceError},
7 lock::{rank, Mutex},
8 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
9 resource::{StagingBuffer, Trackable},
10 snatch::SnatchGuard,
11 track::TrackerIndex,
12 FastHashMap,
13};
14use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
15use core::{
16 mem::{size_of, size_of_val},
17 num::NonZeroU64,
18};
19use wgt::Limits;
20
// Size of every pooled buffer. 1_048_560 = 2^20 - 16, which is divisible by 4,
// 16 (the size of `MetadataEntry`) and 20, so whole entries/indirect-arg sets
// pack into a buffer without crossing its end.
// SAFETY: 1_048_560 is non-zero.
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };
36
/// GPU resources used to validate indirect draw arguments with a compute
/// pass before the real draws consume them.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    // Bind group layouts for group 0 (metadata, read-only), group 1 (source
    // indirect buffer, read-only + dynamic offset) and group 2 (destination
    // buffer, read-write). See `Draw::new`.
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    // Pools of reusable buffer + bind-group pairs for destination (indirect)
    // and metadata storage; entries are recycled via the acquire/release
    // methods below.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
60
61impl Draw {
    /// Creates the draw-validation shader module, the three bind group
    /// layouts, the pipeline layout and the compute pipeline.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        // Group 0: read-only metadata buffer, bound whole (BUFFER_SIZE).
        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        // Group 1: read-only source buffer with a dynamic offset; minimum
        // binding size of 4 u32s (16 bytes) of indirect arguments.
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        // Group 2: read-write destination buffer, bound whole (BUFFER_SIZE).
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            // 8 bytes = two u32 push constants: the metadata start index and
            // the entry count (set per batch in `inject_validation_pass`).
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        // Specialize the pipeline on feature support and backend quirks.
        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }
117
    /// Creates the group-1 bind group for a source (indirect) buffer, bound
    /// at offset 0 with the binding size computed by
    /// `calculate_src_buffer_binding_size`.
    ///
    /// Returns `Ok(None)` when the computed binding size is zero, in which
    /// case no bind group is needed.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // `new_unchecked`: the binding size was derived from
            // `buffer_size`, so 0..binding_size lies within the buffer.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }
152
153 fn acquire_dst_entry(
154 &self,
155 device: &dyn hal::DynDevice,
156 ) -> Result<BufferPoolEntry, hal::DeviceError> {
157 let mut free_buffers = self.free_indirect_entries.lock();
158 match free_buffers.pop() {
159 Some(buffer) => Ok(buffer),
160 None => {
161 let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
162 create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
163 }
164 }
165 }
166
167 fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
168 self.free_indirect_entries.lock().extend(entries);
169 }
170
171 fn acquire_metadata_entry(
172 &self,
173 device: &dyn hal::DynDevice,
174 ) -> Result<BufferPoolEntry, hal::DeviceError> {
175 let mut free_buffers = self.free_metadata_entries.lock();
176 match free_buffers.pop() {
177 Some(buffer) => Ok(buffer),
178 None => {
179 let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
180 create_buffer_and_bind_group(
181 device,
182 usage,
183 self.metadata_bind_group_layout.as_ref(),
184 )
185 }
186 }
187 }
188
189 fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
190 self.free_metadata_entries.lock().extend(entries);
191 }
192
    /// Encodes a compute pass that uploads each batch's metadata and runs the
    /// validation pipeline over it, writing clamped indirect arguments into
    /// the pooled destination buffers held by `resources`.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), DeviceError> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        // Pack all batches' metadata into staging buffers of at most 64 MiB.
        let max_staging_buffer_size = 1 << 26;
        let mut staging_buffers = Vec::new();

        // First pass: assign every batch a (staging buffer index, offset).
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                // The staging buffer being filled is full: materialize it and
                // start filling the next one.
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            // `staging_buffers.len()` is the index the buffer currently being
            // filled will get once it is pushed (on overflow or below).
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        // Materialize the final, partially-filled staging buffer.
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Second pass: write each batch's metadata bytes at its assigned spot.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        // Flush the staging buffers so their contents are visible to the GPU.
        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Reserve a subrange of a pooled metadata buffer for every batch.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Transition staging buffers to COPY_SRC and metadata buffers to
        // COPY_DST ahead of the staging -> metadata copies (each buffer is
        // barriered once thanks to `unique`).
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy each batch's metadata from staging into its metadata subrange.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // Keep the staging buffers alive until the GPU has finished with them.
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Transition metadata buffers back to shader-readable and destination
        // buffers to writable for the validation dispatches.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        // One dispatch per batch: ceil(entries / 64) workgroups.
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Push constants: index of this batch's first metadata entry and
            // the number of entries to validate.
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_push_constants(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            // Group 0: metadata buffer.
            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            // Group 1: source buffer at this batch's dynamic offset.
            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            // Group 2: destination buffer for the validated arguments.
            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Transition destination buffers back to INDIRECT so the subsequent
        // draws can consume the validated arguments.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }
429
430 pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
431 let Draw {
432 module,
433 metadata_bind_group_layout,
434 src_bind_group_layout,
435 dst_bind_group_layout,
436 pipeline_layout,
437 pipeline,
438
439 free_indirect_entries,
440 free_metadata_entries,
441 } = self;
442
443 for entry in free_indirect_entries.into_inner().drain(..) {
444 unsafe {
445 device.destroy_bind_group(entry.bind_group);
446 device.destroy_buffer(entry.buffer);
447 }
448 }
449
450 for entry in free_metadata_entries.into_inner().drain(..) {
451 unsafe {
452 device.destroy_bind_group(entry.bind_group);
453 device.destroy_buffer(entry.buffer);
454 }
455 }
456
457 unsafe {
458 device.destroy_compute_pipeline(pipeline);
459 device.destroy_pipeline_layout(pipeline_layout);
460 device.destroy_bind_group_layout(metadata_bind_group_layout);
461 device.destroy_bind_group_layout(src_bind_group_layout);
462 device.destroy_bind_group_layout(dst_bind_group_layout);
463 device.destroy_shader_module(module);
464 }
465 }
466}
467
/// Parses, validates, and creates the draw-validation compute shader module
/// from the embedded WGSL source.
fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    // Without the WGSL frontend the validation shader cannot be built at all.
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    // The shader uses push constants, so validate with that feature enabled.
    let info = crate::device::create_validator(
        wgt::Features::PUSH_CONSTANTS,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        // Internal shader: built without runtime bounds checks.
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}
521
/// Creates the validation compute pipeline, specializing the shader's
/// override constants on first-instance support and the DX12 special
/// constants quirk.
fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            // Pipeline-overridable constants; booleans are passed as 0.0/1.0.
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}
565
566fn create_bind_group_layout(
567 device: &dyn hal::DynDevice,
568 read_only: bool,
569 has_dynamic_offset: bool,
570 min_binding_size: wgt::BufferSize,
571) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
572 let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
573 label: None,
574 flags: hal::BindGroupLayoutFlags::empty(),
575 entries: &[wgt::BindGroupLayoutEntry {
576 binding: 0,
577 visibility: wgt::ShaderStages::COMPUTE,
578 ty: wgt::BindingType::Buffer {
579 ty: wgt::BufferBindingType::Storage { read_only },
580 has_dynamic_offset,
581 min_binding_size: Some(min_binding_size),
582 },
583 count: None,
584 }],
585 };
586 let bind_group_layout = unsafe {
587 device
588 .create_bind_group_layout(&bind_group_layout_desc)
589 .map_err(DeviceError::from_hal)?
590 };
591
592 Ok(bind_group_layout)
593}
594
595fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
597 let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
598 let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
599
600 if buffer_size <= max_storage_buffer_binding_size {
601 buffer_size
602 } else {
603 let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
604 let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;
605
606 if buffer_rem <= binding_rem {
609 max_storage_buffer_binding_size - binding_rem + buffer_rem
610 }
611 else {
613 max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
614 + buffer_rem
615 }
616 }
617}
618
/// Splits `offset` into an aligned dynamic offset plus a residual offset for
/// the source bind group, such that the residual stays addressable within the
/// binding computed by `calculate_src_buffer_binding_size`.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // The stride between successive dynamic offsets is shrunk by a few
    // alignment-sized chunks so a full indirect-args read starting anywhere
    // inside a stride still fits within the binding.
    // NOTE(review): the per-alignment values (0 for 4, 2 for 8, 1 for >=16)
    // presumably account for the maximum indirect-args size relative to the
    // alignment — confirm against `validate_draw.wgsl`.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        4 => 0,
        8 => 2,
        16.. => 1,
        // Alignments below 4 are assumed not to occur — TODO confirm the
        // limit is validated elsewhere.
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    // Binding too small to step through the buffer: bind at 0 and resolve the
    // whole offset inside the binding.
    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    // Largest dynamic offset that keeps the binding within the buffer.
    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    // Clamp so the binding never runs past the end of the buffer; whatever is
    // left over becomes the in-binding offset.
    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
658
/// A pooled `BUFFER_SIZE`-byte buffer together with the bind group that binds
/// the whole buffer at offset 0 (see `create_buffer_and_bind_group`).
#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}
664
/// Creates a `BUFFER_SIZE` buffer with the given usage plus a bind group that
/// binds the whole buffer at offset 0 using the given layout.
fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        // `new_unchecked`: the range 0..BUFFER_SIZE is exactly the buffer
        // created above, so the binding is in bounds.
        buffers: &[hal::BufferBinding::new_unchecked(
            buffer.as_ref(),
            0,
            BUFFER_SIZE,
        )],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
        external_textures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}
699
/// Bump-allocation cursor into a buffer pool: the index of the pool entry
/// currently being filled and the offset of the next free byte within it
/// (see `DrawResources::get_subrange_impl`).
#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}
705
/// Pool entries acquired from the device's draw-validation pools for one
/// round of command encoding; returned to the pools when dropped.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    // Destination (validated indirect args) and metadata buffer entries,
    // indexed by the resource indices handed out by the subrange methods.
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}
712
713impl Drop for DrawResources {
714 fn drop(&mut self) {
715 if let Some(ref indirect_validation) = self.device.indirect_validation {
716 let indirect_draw_validation = &indirect_validation.draw;
717 indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
718 indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
719 }
720 }
721}
722
723impl DrawResources {
724 pub(crate) fn new(device: Arc<Device>) -> Self {
725 DrawResources {
726 device,
727 dst_entries: Vec::new(),
728 metadata_entries: Vec::new(),
729 }
730 }
731
732 pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
733 self.dst_entries.get(index).unwrap().buffer.as_ref()
734 }
735
736 fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
737 self.dst_entries.get(index).unwrap().bind_group.as_ref()
738 }
739
740 fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
741 self.metadata_entries.get(index).unwrap().buffer.as_ref()
742 }
743
744 fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
745 self.metadata_entries
746 .get(index)
747 .unwrap()
748 .bind_group
749 .as_ref()
750 }
751
    /// Reserves `size` bytes in a pooled destination buffer, growing the pool
    /// on demand. Returns the (entry index, byte offset) of the subrange.
    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Acquire a pool entry lazily when `index` is past the end of
        // `dst_entries`.
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
768
    /// Reserves `size` bytes in a pooled metadata buffer, growing the pool on
    /// demand. Returns the (entry index, byte offset) of the subrange.
    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Acquire a pool entry lazily when `index` is past the end of
        // `metadata_entries`.
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
785
    /// Bump-allocates a `size`-byte subrange: reuses `current_entry`'s buffer
    /// while the request still fits under `BUFFER_SIZE`, otherwise moves to
    /// the next pool index (creating it via `ensure_entry`).
    ///
    /// Returns the entry index and starting offset of the reserved subrange.
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                // Fits in the current buffer: hand out the current cursor and
                // bump it past this subrange.
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                // Doesn't fit; start a fresh buffer at the next pool index.
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        // Fresh buffer: the subrange starts at offset 0.
        // NOTE(review): assumes `size <= BUFFER_SIZE`; a larger request would
        // overrun the entry — confirm callers guarantee this.
        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
814}
815
/// Per-draw metadata consumed by the validation shader.
/// NOTE(review): presumably mirrors a struct in `validate_draw.wgsl` —
/// confirm field order when changing.
#[repr(C)]
struct MetadataEntry {
    // u32-word offset of the draw args in the source buffer; bit 31 flags an
    // indexed draw (see `MetadataEntry::new`).
    src_offset: u32,
    // u32-word offset into the destination buffer; bits 30 and 31 carry bit
    // 32 of the vertex/index and instance limits respectively.
    dst_offset: u32,
    // Low 32 bits of the clamped limits.
    vertex_or_index_limit: u32,
    instance_limit: u32,
}
824
825impl MetadataEntry {
826 fn new(
827 indexed: bool,
828 src_offset: u64,
829 dst_offset: u64,
830 vertex_or_index_limit: u64,
831 instance_limit: u64,
832 ) -> Self {
833 debug_assert_eq!(
834 4,
835 size_of_val(&Limits::default().max_storage_buffer_binding_size)
836 );
837
838 let src_offset = src_offset as u32; let src_offset = src_offset / 4; let src_offset = src_offset | ((indexed as u32) << 31);
844
845 let max_limit = u32::MAX as u64 + u32::MAX as u64; let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
849 let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; let vertex_or_index_limit = vertex_or_index_limit as u32; let instance_limit = instance_limit.min(max_limit);
853 let instance_limit_bit_32 = (instance_limit >> 32) as u32; let instance_limit = instance_limit as u32; let dst_offset = dst_offset as u32; let dst_offset = dst_offset / 4; let dst_offset =
863 dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
864
865 Self {
866 src_offset,
867 dst_offset,
868 vertex_or_index_limit,
869 instance_limit,
870 }
871 }
872}
873
/// A group of draws sharing the same source buffer, source dynamic offset and
/// destination resource, validated by a single compute dispatch.
struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    // One metadata entry per draw in this batch.
    entries: Vec<MetadataEntry>,

    // Assigned later by `Draw::inject_validation_pass` when staging and
    // metadata space is allocated; zero until then.
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}
885
impl DrawIndirectValidationBatch {
    /// Returns this batch's metadata entries reinterpreted as raw bytes for
    /// upload to the GPU.
    fn metadata(&self) -> &[u8] {
        // SAFETY: the pointer and length describe exactly the memory owned by
        // `self.entries` (`len * size_of::<MetadataEntry>()` bytes);
        // `MetadataEntry` is `#[repr(C)]` with four u32 fields and therefore
        // no padding, so all bytes are initialized.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}
897
/// Accumulates indirect draws into batches keyed by (source buffer tracker
/// index, source dynamic offset, destination resource index), to be encoded
/// by `Draw::inject_validation_pass`.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    // Bump-allocation cursor into the destination buffer pool, shared across
    // `add` calls.
    current_dst_entry: Option<CurrentEntry>,
}
903
impl DrawBatcher {
    /// Creates an empty batcher.
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

    /// Registers one indirect draw for validation.
    ///
    /// Reserves space for the validated arguments in a pooled destination
    /// buffer and records a `MetadataEntry` in the batch keyed by this draw's
    /// (source buffer, source dynamic offset, destination resource). Returns
    /// the destination resource index and byte offset where the validated
    /// arguments will be written.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // On DX12 three extra u32s are reserved ahead of the draw args (see
        // `write_d3d12_special_constants` in `Draw::new`).
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        // Reserve destination space for this draw's validated arguments.
        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        // Split the source offset into an aligned dynamic offset (part of the
        // batch key) and a residual offset (stored per entry).
        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        // Append to an existing batch with the same key, or start a new one.
        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // Assigned later by `Draw::inject_validation_pass`.
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}