1use super::{
2 utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3 CreateIndirectValidationPipelineError,
4};
5use crate::{
6 command::RenderPassErrorInner,
7 device::{queue::TempResource, Device, DeviceError},
8 lock::{rank, Mutex},
9 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
10 resource::{RawResourceAccess as _, StagingBuffer, Trackable},
11 snatch::SnatchGuard,
12 track::TrackerIndex,
13 FastHashMap,
14};
15use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
16use core::{
17 mem::{size_of, size_of_val},
18 num::NonZeroU64,
19};
20use wgt::Limits;
21
/// Fixed size of every pooled metadata / destination buffer.
///
/// `1_048_560 == 65_535 * 16`: the largest multiple of 16 (the size of one
/// `MetadataEntry`) below 1 MiB.
// SAFETY: the literal is non-zero.
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };
37
/// Device-wide state for validating indirect draw calls.
///
/// Owns the compute pipeline that checks/clamps indirect draw arguments,
/// plus pools of fixed-size (`BUFFER_SIZE`) buffers reused across passes.
#[derive(Debug)]
pub(crate) struct Draw {
    /// Shader module built from the embedded `validate_draw.wgsl`.
    module: Box<dyn hal::DynShaderModule>,
    /// Layout for bind group 0: read-only per-draw metadata.
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 1: the user's source indirect buffer
    /// (read-only, with a dynamic offset).
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 2: the writable destination buffer that
    /// receives validated arguments.
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    /// Free list of destination (indirect-args) buffer entries.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    /// Free list of metadata buffer entries.
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
61
62impl Draw {
    /// Build the validation shader module, bind group layouts, pipeline
    /// layout, and compute pipeline.
    ///
    /// `required_features` decides whether the shader may pass
    /// `first_instance` through; `backend` enables writing DX12's special
    /// constants.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        // Group 0: read-only metadata, bound over the whole pooled buffer.
        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        // Group 1: read-only source buffer with a dynamic offset; at least
        // 4 * 4 = 16 bytes must be visible through the binding.
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        // Group 2: writable destination buffer.
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            // Two u32 immediates: metadata start index and entry count
            // (see `inject_validation_pass`).
            immediate_size: 8,
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }
115
    /// Create the group-1 bind group for a user's source indirect buffer.
    ///
    /// Returns `Ok(None)` when the computed binding size is zero (nothing
    /// can be bound).
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // `binding_size` never exceeds `buffer_size` (see
            // `calculate_src_buffer_binding_size`), so the unchecked
            // binding stays in-bounds.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }
150
151 fn acquire_dst_entry(
152 &self,
153 device: &dyn hal::DynDevice,
154 ) -> Result<BufferPoolEntry, hal::DeviceError> {
155 let mut free_buffers = self.free_indirect_entries.lock();
156 match free_buffers.pop() {
157 Some(buffer) => Ok(buffer),
158 None => {
159 let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
160 create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
161 }
162 }
163 }
164
165 fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
166 self.free_indirect_entries.lock().extend(entries);
167 }
168
169 fn acquire_metadata_entry(
170 &self,
171 device: &dyn hal::DynDevice,
172 ) -> Result<BufferPoolEntry, hal::DeviceError> {
173 let mut free_buffers = self.free_metadata_entries.lock();
174 match free_buffers.pop() {
175 Some(buffer) => Ok(buffer),
176 None => {
177 let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
178 create_buffer_and_bind_group(
179 device,
180 usage,
181 self.metadata_bind_group_layout.as_ref(),
182 )
183 }
184 }
185 }
186
187 fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
188 self.free_metadata_entries.lock().extend(entries);
189 }
190
    /// Encode the validation work for all batches collected by `batcher`:
    /// upload each batch's metadata through staging buffers, copy it into
    /// pooled metadata buffers, then run one compute dispatch per batch that
    /// writes validated arguments into the destination buffers. Each stage
    /// is bracketed by the buffer barriers it requires.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        // Pack all batches' metadata into staging buffers of at most 64 MiB.
        let max_staging_buffer_size = 1 << 26; let mut staging_buffers = Vec::new();

        // Pass 1: assign each batch a (staging buffer index, byte offset).
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                // The current staging buffer is full: seal it and start a
                // new one with this batch at offset 0.
                // NOTE(review): assumes a single batch's metadata never
                // exceeds the cap — otherwise `current_size` could be 0 here
                // and the `unwrap` below would panic. Confirm upstream bound.
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            // `staging_buffers.len()` is the index of the buffer this batch
            // lands in; that buffer is created later (below, or on a later
            // iteration when it fills up).
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        // Seal the final, partially filled staging buffer.
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Pass 2: write each batch's metadata into its assigned slot.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            // SAFETY: the offsets computed in pass 1 partition each staging
            // buffer, so this write stays within the buffer's size.
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Pass 3: reserve a subrange of a pooled metadata buffer per batch.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Transition staging buffers to COPY_SRC and metadata buffers to
        // COPY_DST for the upcoming copies (each buffer at most once, via
        // `unique`).
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy each batch's metadata from staging into its pooled buffer.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // Keep the staging buffers alive until the GPU work completes.
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Metadata buffers become shader-readable again; destination buffers
        // become shader-writable for the validation dispatches.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Immediates: index of this batch's first metadata entry within
            // the pooled buffer, and the number of entries to process.
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(pipeline_layout, 0, &[metadata_start, metadata_count]);
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            // Error out if the source buffer was destroyed; the raw handle
            // itself is not needed here, only the liveness check.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            // One invocation per entry; the divisor 64 presumably matches
            // the shader's workgroup size — confirm in `validate_draw.wgsl`.
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Destination buffers become INDIRECT sources for the actual draws.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }
425
    /// Destroy every GPU object owned by this validator. Consumes `self` so
    /// the pools cannot be used afterwards.
    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        // Exhaustive destructuring: adding a field without handling it here
        // becomes a compile error.
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        // Drain both buffer pools, destroying each entry's bind group before
        // the buffer it references.
        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        // Destroy in reverse creation order: pipeline before its layout,
        // layouts and module last.
        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
462}
463
/// Parse, validate, and compile the embedded `validate_draw.wgsl` shader
/// into a backend shader module.
fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    // Without the WGSL frontend there is no way to build this internal
    // shader, so fail loudly at runtime.
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    // Only the IMMEDIATES feature is needed by this shader.
    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        // Runtime checks disabled: presumably safe because this module ships
        // with the crate and was validated just above — confirm.
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}
517
/// Create the validation compute pipeline.
///
/// The two boolean flags are handed to the shader as pipeline-overridable
/// constants (encoded as 0.0 / 1.0 via `f64::from`).
fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    // Translate backend pipeline errors into this crate's error types.
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}
561
562fn create_bind_group_layout(
563 device: &dyn hal::DynDevice,
564 read_only: bool,
565 has_dynamic_offset: bool,
566 min_binding_size: wgt::BufferSize,
567) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
568 let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
569 label: None,
570 flags: hal::BindGroupLayoutFlags::empty(),
571 entries: &[wgt::BindGroupLayoutEntry {
572 binding: 0,
573 visibility: wgt::ShaderStages::COMPUTE,
574 ty: wgt::BindingType::Buffer {
575 ty: wgt::BufferBindingType::Storage { read_only },
576 has_dynamic_offset,
577 min_binding_size: Some(min_binding_size),
578 },
579 count: None,
580 }],
581 };
582 let bind_group_layout = unsafe {
583 device
584 .create_bind_group_layout(&bind_group_layout_desc)
585 .map_err(DeviceError::from_hal)?
586 };
587
588 Ok(bind_group_layout)
589}
590
/// Compute the binding size used for a source indirect buffer.
///
/// If the whole buffer fits within `max_storage_buffer_binding_size`, bind
/// it all. Otherwise return the largest size not exceeding that limit which
/// is congruent to `buffer_size` modulo
/// `min_storage_buffer_offset_alignment`, so `buffer_size - binding_size`
/// (the maximum dynamic offset, see `calculate_src_offsets`) is a multiple
/// of the required alignment.
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        if buffer_rem <= binding_rem {
            // Swapping the limit's residue for the buffer's residue stays
            // at or below the limit.
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        }
        else {
            // The buffer's residue is larger, so back off one alignment
            // step to stay below the limit.
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}
614
/// Split `offset` into `(src_dynamic_offset, src_offset)` where the first
/// component is a multiple of the dynamic-offset stride (itself a multiple
/// of `min_storage_buffer_offset_alignment`) and
/// `src_dynamic_offset + src_offset == offset`.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // Shrink the stride by a few alignment chunks so the residual
    // `src_offset` plus the indirect-arguments read still fits inside
    // `binding_size`.
    // NOTE(review): the 0/2/1 adjustments for alignments 4/8/16+ look
    // deliberate but their derivation is not visible in this file — confirm
    // against the shader's access pattern.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        4 => 0,
        8 => 2,
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    // Degenerate binding: no dynamic stepping, everything is in-binding.
    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    // Clamp so the binding window never extends past the end of the buffer.
    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
654
/// A pooled fixed-size (`BUFFER_SIZE`) buffer together with the bind group
/// that exposes it to the validation pipeline.
#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}
660
661fn create_buffer_and_bind_group(
662 device: &dyn hal::DynDevice,
663 usage: wgt::BufferUses,
664 bind_group_layout: &dyn hal::DynBindGroupLayout,
665) -> Result<BufferPoolEntry, hal::DeviceError> {
666 let buffer_desc = hal::BufferDescriptor {
667 label: None,
668 size: BUFFER_SIZE.get(),
669 usage,
670 memory_flags: hal::MemoryFlags::empty(),
671 };
672 let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
673 let bind_group_desc = hal::BindGroupDescriptor {
674 label: None,
675 layout: bind_group_layout,
676 entries: &[hal::BindGroupEntry {
677 binding: 0,
678 resource_index: 0,
679 count: 1,
680 }],
681 buffers: &[hal::BufferBinding::new_unchecked(
683 buffer.as_ref(),
684 0,
685 BUFFER_SIZE,
686 )],
687 samplers: &[],
688 textures: &[],
689 acceleration_structures: &[],
690 external_textures: &[],
691 };
692 let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
693 Ok(BufferPoolEntry { buffer, bind_group })
694}
695
/// Cursor into the buffer-pool entry currently being bump-allocated from.
#[derive(Clone)]
struct CurrentEntry {
    /// Index of the pool entry.
    index: usize,
    /// Next free byte offset within that entry's buffer.
    offset: u64,
}
701
/// Per-use collection of destination and metadata buffer entries for draw
/// validation. Entries are acquired from the device-wide `Draw` pools and
/// handed back when this value is dropped.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    /// Destination (validated indirect-args) entries, indexed by resource index.
    dst_entries: Vec<BufferPoolEntry>,
    /// Metadata entries, indexed by resource index.
    metadata_entries: Vec<BufferPoolEntry>,
}
708
709impl Drop for DrawResources {
710 fn drop(&mut self) {
711 if let Some(ref indirect_validation) = self.device.indirect_validation {
712 let indirect_draw_validation = &indirect_validation.draw;
713 indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
714 indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
715 }
716 }
717}
718
719impl DrawResources {
720 pub(crate) fn new(device: Arc<Device>) -> Self {
721 DrawResources {
722 device,
723 dst_entries: Vec::new(),
724 metadata_entries: Vec::new(),
725 }
726 }
727
    /// Destination buffer holding validated indirect arguments.
    ///
    /// Panics if `index` was never allocated via `get_dst_subrange`.
    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    /// Bind group (group 2) for the destination buffer at `index`.
    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    /// Metadata buffer at `index`; panics if never allocated.
    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    /// Bind group (group 0) for the metadata buffer at `index`.
    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }
747
    /// Reserve `size` bytes in a destination buffer, acquiring a new pool
    /// entry from the device-wide free list when needed.
    ///
    /// Returns `(entry index, byte offset)`.
    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Lazily grow `dst_entries` so that `index` is valid.
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    /// Reserve `size` bytes in a metadata buffer; mirror of
    /// `get_dst_subrange` using the metadata pool.
    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Lazily grow `metadata_entries` so that `index` is valid.
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
781
    /// Bump-allocate `size` bytes from a sequence of fixed-size
    /// (`BUFFER_SIZE`) entries.
    ///
    /// If the request fits in the entry tracked by `current_entry`, the
    /// cursor simply advances; otherwise allocation moves to the next entry
    /// index (created on demand via `ensure_entry`) starting at offset 0.
    /// `current_entry` is left pointing just past the returned range.
    ///
    /// NOTE(review): a request with `size > BUFFER_SIZE` would still be
    /// "placed" at offset 0 of a fresh entry and overflow it; callers appear
    /// to keep sizes well below `BUFFER_SIZE` — confirm.
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                // Fits in the current entry: return its cursor and advance.
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                // Entry exhausted: move to the next one.
                current_entry.index + 1
            }
        } else {
            // First allocation ever: start with entry 0.
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
810}
811
/// GPU-side descriptor of one indirect draw to validate. Four `u32`s,
/// `#[repr(C)]`, uploaded as raw bytes; see `MetadataEntry::new` for the bit
/// packing. Presumably mirrors a struct in `validate_draw.wgsl` — confirm.
#[repr(C)]
struct MetadataEntry {
    /// Source offset in 4-byte words; bit 31 flags an indexed draw.
    src_offset: u32,
    /// Destination offset in 4-byte words; bits 30/31 carry bit 32 of the
    /// vertex-or-index limit and instance limit respectively.
    dst_offset: u32,
    /// Low 32 bits of the clamped vertex-or-index limit.
    vertex_or_index_limit: u32,
    /// Low 32 bits of the clamped instance limit.
    instance_limit: u32,
}
820
821impl MetadataEntry {
822 fn new(
823 indexed: bool,
824 src_offset: u64,
825 dst_offset: u64,
826 vertex_or_index_limit: u64,
827 instance_limit: u64,
828 ) -> Self {
829 debug_assert_eq!(
830 4,
831 size_of_val(&Limits::default().max_storage_buffer_binding_size)
832 );
833
834 let src_offset = src_offset as u32; let src_offset = src_offset / 4; let src_offset = src_offset | ((indexed as u32) << 31);
840
841 let max_limit = u32::MAX as u64 + u32::MAX as u64; let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
845 let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; let vertex_or_index_limit = vertex_or_index_limit as u32; let instance_limit = instance_limit.min(max_limit);
849 let instance_limit_bit_32 = (instance_limit >> 32) as u32; let instance_limit = instance_limit as u32; let dst_offset = dst_offset as u32; let dst_offset = dst_offset / 4; let dst_offset =
859 dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
860
861 Self {
862 src_offset,
863 dst_offset,
864 vertex_or_index_limit,
865 instance_limit,
866 }
867 }
868}
869
/// One validation dispatch: all draws that share a source buffer, source
/// dynamic offset, and destination pool entry.
struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    /// Aligned dynamic offset applied to the group-1 bind group.
    src_dynamic_offset: u64,
    /// Index of the destination pool entry receiving validated args.
    dst_resource_index: usize,
    /// Packed per-draw metadata, uploaded to the GPU before dispatch.
    entries: Vec<MetadataEntry>,

    // The fields below start at 0 and are filled in by
    // `Draw::inject_validation_pass`.
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}
881
impl DrawIndirectValidationBatch {
    /// View this batch's `MetadataEntry` list as raw bytes for upload.
    fn metadata(&self) -> &[u8] {
        // SAFETY: `entries` holds `len` contiguous, initialized
        // `MetadataEntry` values; the struct is `#[repr(C)]` with four
        // `u32` fields (no padding), so reinterpreting the allocation as
        // `len * size_of::<MetadataEntry>()` bytes is valid for the
        // lifetime of the borrow.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}
893
/// Accumulates indirect draws into validation batches.
///
/// Draws are grouped by `(source buffer tracker index, source dynamic
/// offset, destination entry index)` so each group can be handled by a
/// single compute dispatch in `Draw::inject_validation_pass`.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    /// Bump-allocation cursor for destination subranges, shared across
    /// `add` calls.
    current_dst_entry: Option<CurrentEntry>,
}
899
900impl DrawBatcher {
901 pub(crate) fn new() -> Self {
902 Self {
903 batches: FastHashMap::default(),
904 current_dst_entry: None,
905 }
906 }
907
    /// Register one indirect draw for validation.
    ///
    /// Reserves `stride` bytes of destination space, computes the aligned
    /// source dynamic offset, and appends a packed `MetadataEntry` to the
    /// batch keyed by (source buffer, dynamic offset, destination entry).
    ///
    /// Returns `(destination entry index, byte offset)` where the validated
    /// arguments will be written.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // On DX12, reserve room for 3 extra u32s ahead of the arguments —
        // presumably the "special constants" written when
        // `write_d3d12_special_constants` is set; confirm in the shader.
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        // Split the caller's offset into an aligned dynamic offset plus an
        // in-binding offset the shader applies.
        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        // Append to an existing batch with the same key, or start a new one.
        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // Filled in later by `Draw::inject_validation_pass`.
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
973}