1use super::{
2 utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3 CreateIndirectValidationPipelineError,
4};
5use crate::{
6 command::RenderPassErrorInner,
7 device::{queue::TempResource, Device, DeviceError},
8 lock::{rank, Mutex},
9 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
10 resource::{RawResourceAccess as _, StagingBuffer, Trackable},
11 snatch::SnatchGuard,
12 track::TrackerIndex,
13 FastHashMap,
14};
15use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
16use core::{
17 mem::{size_of, size_of_val},
18 num::NonZeroU64,
19};
20use wgt::Limits;
21
/// Size in bytes of each pooled metadata/destination buffer.
///
/// SAFETY: 1_048_560 is non-zero, so `new_unchecked` is sound.
///
/// NOTE(review): 1_048_560 = 1 MiB - 16. It is an exact multiple of 16
/// (`size_of::<MetadataEntry>()`) and of 20 (the indexed indirect-args
/// stride), so entries pack into pooled buffers without straddling the end —
/// TODO confirm that is the intended rationale for this value.
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };
37
/// GPU state used to validate indirect draw arguments with a compute pass
/// (see `validate_draw.wgsl`).
#[derive(Debug)]
pub(crate) struct Draw {
    /// Compute shader module built from `validate_draw.wgsl`.
    module: Box<dyn hal::DynShaderModule>,
    /// Layout for bind group 0: read-only storage with per-draw metadata.
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 1: the user's source indirect buffer
    /// (read-only storage, bound with a dynamic offset).
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout for bind group 2: the pooled destination buffer the shader
    /// writes validated arguments into (read-write storage).
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    /// Pool of reusable destination (INDIRECT-usage) buffers + bind groups.
    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    /// Pool of reusable metadata (COPY_DST-usage) buffers + bind groups.
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}
61
62impl Draw {
    /// Creates the shader module, bind group layouts, pipeline layout, and
    /// compute pipeline used for indirect draw validation, plus empty buffer
    /// pools.
    ///
    /// # Errors
    /// Returns an error if shader parsing/validation or any hal resource
    /// creation fails.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        // Group 0: read-only metadata, no dynamic offset.
        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        // Group 1: read-only source args with a dynamic offset; min binding
        // size of 16 bytes covers one non-indexed indirect args struct.
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        // Group 2: writable destination buffer.
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            // 8 bytes of immediates: metadata_start + metadata_count, set
            // per batch in `inject_validation_pass`.
            immediates_ranges: &[wgt::ImmediateRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        // Specialize the pipeline for the device: whether firstInstance may
        // be non-zero, and whether DX12 special constants must be emitted.
        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }
118
119 pub(super) fn create_src_bind_group(
121 &self,
122 device: &dyn hal::DynDevice,
123 limits: &Limits,
124 buffer_size: u64,
125 buffer: &dyn hal::DynBuffer,
126 ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
127 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
128 let Some(binding_size) = NonZeroU64::new(binding_size) else {
129 return Ok(None);
130 };
131 let hal_desc = hal::BindGroupDescriptor {
132 label: None,
133 layout: self.src_bind_group_layout.as_ref(),
134 entries: &[hal::BindGroupEntry {
135 binding: 0,
136 resource_index: 0,
137 count: 1,
138 }],
139 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
141 samplers: &[],
142 textures: &[],
143 acceleration_structures: &[],
144 external_textures: &[],
145 };
146 unsafe {
147 device
148 .create_bind_group(&hal_desc)
149 .map(Some)
150 .map_err(DeviceError::from_hal)
151 }
152 }
153
154 fn acquire_dst_entry(
155 &self,
156 device: &dyn hal::DynDevice,
157 ) -> Result<BufferPoolEntry, hal::DeviceError> {
158 let mut free_buffers = self.free_indirect_entries.lock();
159 match free_buffers.pop() {
160 Some(buffer) => Ok(buffer),
161 None => {
162 let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
163 create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
164 }
165 }
166 }
167
168 fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
169 self.free_indirect_entries.lock().extend(entries);
170 }
171
172 fn acquire_metadata_entry(
173 &self,
174 device: &dyn hal::DynDevice,
175 ) -> Result<BufferPoolEntry, hal::DeviceError> {
176 let mut free_buffers = self.free_metadata_entries.lock();
177 match free_buffers.pop() {
178 Some(buffer) => Ok(buffer),
179 None => {
180 let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
181 create_buffer_and_bind_group(
182 device,
183 usage,
184 self.metadata_bind_group_layout.as_ref(),
185 )
186 }
187 }
188 }
189
190 fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
191 self.free_metadata_entries.lock().extend(entries);
192 }
193
    /// Encodes the validation work for all batched indirect draws: uploads
    /// per-draw metadata via staging buffers, transitions buffers with the
    /// required barriers, and dispatches the validation pipeline once per
    /// batch so it writes checked arguments into the pooled destination
    /// buffers.
    ///
    /// # Errors
    /// Fails if staging-buffer or pooled-buffer allocation fails, or if a
    /// batch's source buffer has been destroyed.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), RenderPassErrorInner> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        // Soft cap of 64 MiB per staging buffer; batches are packed into as
        // few staging buffers as possible.
        let max_staging_buffer_size = 1 << 26;
        let mut staging_buffers = Vec::new();

        // Pass 1: assign each batch a (staging buffer index, offset) slot.
        // A buffer is finalized (pushed) when the next batch would overflow
        // it; `staging_buffer_index` equals the number of buffers finalized
        // so far, i.e. the index the in-progress buffer will get.
        //
        // NOTE(review): if the very first batch's metadata alone exceeds the
        // cap, `current_size` is 0 here and `NonZeroU64::new(0).unwrap()`
        // would panic — presumably unreachable in practice, but confirm.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        // Finalize the last, partially filled staging buffer.
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Pass 2: CPU-write each batch's metadata bytes into its slot.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            // SAFETY: the slot [offset, offset + data.len()) was reserved
            // for exactly this batch's metadata by pass 1 above.
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Assign each batch a subrange of a pooled GPU metadata buffer (the
        // storage buffer the validation shader reads from).
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Barriers before the copies: staging buffers become copy sources,
        // metadata buffers become copy destinations. Indices are
        // deduplicated so each buffer gets exactly one barrier.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        // Copy each batch's metadata from its staging slot into its pooled
        // metadata buffer subrange.
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        // Keep the staging buffers alive until the GPU has consumed them.
        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Barriers before the compute pass: metadata buffers back to
        // read-only storage, destination buffers become writable storage.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        // One dispatch per batch: point the shader at this batch's slice of
        // the metadata buffer and its source/destination bind groups.
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            // Immediates match the 0..8 range declared in `Draw::new`:
            // starting entry index (in MetadataEntry units) and entry count.
            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_immediates(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            // Error out (rather than encode a dispatch reading freed memory)
            // if the source buffer has been destroyed.
            batch.src_buffer.try_raw(snatch_guard)?;

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            // One thread per draw entry; 64 presumably matches the shader's
            // @workgroup_size — TODO confirm against validate_draw.wgsl.
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Final barriers: destination buffers return to INDIRECT usage so
        // the subsequent render pass can consume them.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }
433
    /// Destroys every GPU resource owned by this validator, including all
    /// pooled buffers. Consumes `self`, so nothing can be used afterwards.
    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        // Destructure exhaustively so adding a field without disposing it
        // here becomes a compile error.
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        // For each pooled entry, destroy the bind group before the buffer
        // it references.
        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        // Destroy the pipeline first, then the layouts/module it was
        // created from (reverse of creation order in `new`).
        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
470}
471
/// Parses and validates `validate_draw.wgsl` and creates the hal shader
/// module used by the draw-validation pipeline.
///
/// # Errors
/// Returns parsing, validation, or device errors as
/// `CreateIndirectValidationPipelineError`.
///
/// # Panics
/// Panics when the `wgsl` feature is disabled, since the shader source can
/// only be parsed by the WGSL frontend.
fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    // Validate with only the features the shader actually needs (immediates
    // for the metadata_start/count values pushed per batch).
    let info = crate::device::create_validator(
        wgt::Features::IMMEDIATES,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        // This is trusted, internally generated code; runtime checks would
        // only add overhead to the validation pass itself.
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {msg}");
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}
525
/// Creates the compute pipeline for draw validation, specializing the
/// shader via pipeline constants for first-instance support and DX12
/// special-constant emission.
///
/// NOTE(review): the constant names presumably match `override` declarations
/// in `validate_draw.wgsl` — confirm if the shader changes.
fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            // Booleans are passed as 0.0/1.0 since pipeline constants are f64.
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    // Map each hal pipeline error variant onto the corresponding
    // compute-pipeline error callers expect.
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}
569
570fn create_bind_group_layout(
571 device: &dyn hal::DynDevice,
572 read_only: bool,
573 has_dynamic_offset: bool,
574 min_binding_size: wgt::BufferSize,
575) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
576 let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
577 label: None,
578 flags: hal::BindGroupLayoutFlags::empty(),
579 entries: &[wgt::BindGroupLayoutEntry {
580 binding: 0,
581 visibility: wgt::ShaderStages::COMPUTE,
582 ty: wgt::BindingType::Buffer {
583 ty: wgt::BufferBindingType::Storage { read_only },
584 has_dynamic_offset,
585 min_binding_size: Some(min_binding_size),
586 },
587 count: None,
588 }],
589 };
590 let bind_group_layout = unsafe {
591 device
592 .create_bind_group_layout(&bind_group_layout_desc)
593 .map_err(DeviceError::from_hal)?
594 };
595
596 Ok(bind_group_layout)
597}
598
599fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
601 let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
602 let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;
603
604 if buffer_size <= max_storage_buffer_binding_size {
605 buffer_size
606 } else {
607 let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
608 let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;
609
610 if buffer_rem <= binding_rem {
613 max_storage_buffer_binding_size - binding_rem + buffer_rem
614 }
615 else {
617 max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
618 + buffer_rem
619 }
620 }
621}
622
/// Splits a byte `offset` into the source buffer into a bind-group dynamic
/// offset plus a residual offset within the binding, returning
/// `(src_dynamic_offset, src_offset)` with
/// `src_dynamic_offset + src_offset == offset`.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // NOTE(review): the stride is shrunk by a few alignment chunks,
    // presumably so that a full set of indirect draw arguments starting
    // near the end of one stride still fits inside the binding: 2 chunks
    // (16 B) at 8-byte alignment, 1 chunk at >=16-byte alignment, none at
    // 4-byte alignment — TODO confirm the derivation of these values.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        4 => 0,
        8 => 2,
        16.. => 1,
        // Alignments are powers of two >= 4, so no other value can occur.
        _ => unreachable!(),
    };

    // Dynamic offsets advance in multiples of this stride, which is itself
    // a multiple of the required alignment.
    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    // Binding too small for any stride: the whole offset stays residual.
    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    // Clamp the dynamic offset so the binding never extends past the end of
    // the buffer.
    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
662
/// A pooled `BUFFER_SIZE`-byte buffer together with the bind group that
/// exposes it (whole) to the validation pipeline.
#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}
668
669fn create_buffer_and_bind_group(
670 device: &dyn hal::DynDevice,
671 usage: wgt::BufferUses,
672 bind_group_layout: &dyn hal::DynBindGroupLayout,
673) -> Result<BufferPoolEntry, hal::DeviceError> {
674 let buffer_desc = hal::BufferDescriptor {
675 label: None,
676 size: BUFFER_SIZE.get(),
677 usage,
678 memory_flags: hal::MemoryFlags::empty(),
679 };
680 let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
681 let bind_group_desc = hal::BindGroupDescriptor {
682 label: None,
683 layout: bind_group_layout,
684 entries: &[hal::BindGroupEntry {
685 binding: 0,
686 resource_index: 0,
687 count: 1,
688 }],
689 buffers: &[hal::BufferBinding::new_unchecked(
691 buffer.as_ref(),
692 0,
693 BUFFER_SIZE,
694 )],
695 samplers: &[],
696 textures: &[],
697 acceleration_structures: &[],
698 external_textures: &[],
699 };
700 let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
701 Ok(BufferPoolEntry { buffer, bind_group })
702}
703
/// Bump-allocation cursor over a sequence of pooled buffers: `index`
/// selects the buffer, `offset` is the next free byte within it.
#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}
709
/// Pooled destination and metadata buffers held for one command encoder.
///
/// Entries are acquired lazily from the device's `Draw` pools and returned
/// to those pools when this value is dropped.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}
716
717impl Drop for DrawResources {
718 fn drop(&mut self) {
719 if let Some(ref indirect_validation) = self.device.indirect_validation {
720 let indirect_draw_validation = &indirect_validation.draw;
721 indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
722 indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
723 }
724 }
725}
726
727impl DrawResources {
728 pub(crate) fn new(device: Arc<Device>) -> Self {
729 DrawResources {
730 device,
731 dst_entries: Vec::new(),
732 metadata_entries: Vec::new(),
733 }
734 }
735
736 pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
737 self.dst_entries.get(index).unwrap().buffer.as_ref()
738 }
739
740 fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
741 self.dst_entries.get(index).unwrap().bind_group.as_ref()
742 }
743
744 fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
745 self.metadata_entries.get(index).unwrap().buffer.as_ref()
746 }
747
748 fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
749 self.metadata_entries
750 .get(index)
751 .unwrap()
752 .bind_group
753 .as_ref()
754 }
755
    /// Reserves `size` bytes in a pooled destination buffer, acquiring a
    /// new buffer from the device pool when the current one is exhausted.
    /// Returns `(entry index, byte offset)`.
    ///
    /// NOTE(review): assumes indirect validation was initialized on this
    /// device — the `unwrap` panics otherwise.
    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Grow the entry list on demand; indices are handed out densely, so
        // only the next index can ever be missing.
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
772
    /// Reserves `size` bytes in a pooled metadata buffer, acquiring a new
    /// buffer from the device pool when the current one is exhausted.
    /// Returns `(entry index, byte offset)`.
    ///
    /// NOTE(review): assumes indirect validation was initialized on this
    /// device — the `unwrap` panics otherwise.
    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        // Grow the entry list on demand; indices are handed out densely, so
        // only the next index can ever be missing.
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }
789
790 fn get_subrange_impl(
791 ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
792 current_entry: &mut Option<CurrentEntry>,
793 size: u64,
794 ) -> Result<CurrentEntry, DeviceError> {
795 let index = if let Some(current_entry) = current_entry.as_mut() {
796 if current_entry.offset + size <= BUFFER_SIZE.get() {
797 let entry_data = current_entry.clone();
798 current_entry.offset += size;
799 return Ok(entry_data);
800 } else {
801 current_entry.index + 1
802 }
803 } else {
804 0
805 };
806
807 ensure_entry(index).map_err(DeviceError::from_hal)?;
808
809 let entry_data = CurrentEntry { index, offset: 0 };
810
811 *current_entry = Some(CurrentEntry {
812 index,
813 offset: size,
814 });
815
816 Ok(entry_data)
817 }
818}
819
/// Per-draw record consumed by the validation shader; `#[repr(C)]` so the
/// byte layout presumably matches the corresponding struct in
/// `validate_draw.wgsl` — TODO confirm.
///
/// See [`MetadataEntry::new`] for the bit packing: offsets are stored in
/// 4-byte units, bit 31 of `src_offset` flags indexed draws, and bits
/// 30/31 of `dst_offset` carry bit 32 of the two clamped limits.
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}
828
829impl MetadataEntry {
830 fn new(
831 indexed: bool,
832 src_offset: u64,
833 dst_offset: u64,
834 vertex_or_index_limit: u64,
835 instance_limit: u64,
836 ) -> Self {
837 debug_assert_eq!(
838 4,
839 size_of_val(&Limits::default().max_storage_buffer_binding_size)
840 );
841
842 let src_offset = src_offset as u32; let src_offset = src_offset / 4; let src_offset = src_offset | ((indexed as u32) << 31);
848
849 let max_limit = u32::MAX as u64 + u32::MAX as u64; let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
853 let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; let vertex_or_index_limit = vertex_or_index_limit as u32; let instance_limit = instance_limit.min(max_limit);
857 let instance_limit_bit_32 = (instance_limit >> 32) as u32; let instance_limit = instance_limit as u32; let dst_offset = dst_offset as u32; let dst_offset = dst_offset / 4; let dst_offset =
867 dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);
868
869 Self {
870 src_offset,
871 dst_offset,
872 vertex_or_index_limit,
873 instance_limit,
874 }
875 }
876}
877
/// A group of indirect draws sharing a source buffer, source dynamic
/// offset, and destination buffer, validated together by one dispatch.
struct DrawIndirectValidationBatch {
    /// Buffer holding the user-supplied indirect draw arguments.
    src_buffer: Arc<crate::resource::Buffer>,
    /// Dynamic offset applied when binding `src_buffer` (bind group 1).
    src_dynamic_offset: u64,
    /// Index of the pooled destination buffer in `DrawResources`.
    dst_resource_index: usize,
    /// One metadata record per draw in this batch.
    entries: Vec<MetadataEntry>,

    // The fields below start at 0 and are filled in by
    // `Draw::inject_validation_pass` when slots are assigned.
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}
889
impl DrawIndirectValidationBatch {
    /// Returns this batch's `MetadataEntry` records viewed as raw bytes,
    /// ready to be written into a staging buffer.
    fn metadata(&self) -> &[u8] {
        // SAFETY: the pointer and length describe exactly the initialized
        // contents of `self.entries`. `MetadataEntry` is `#[repr(C)]` with
        // four `u32` fields (no padding), so every byte is initialized, and
        // `u8` has no alignment requirement. The slice borrows `self`, so
        // it cannot outlive the Vec's allocation.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}
901
/// Accumulates indirect draws into batches ahead of encoding the
/// validation pass.
pub(crate) struct DrawBatcher {
    /// Batches keyed by (source buffer tracker index, source dynamic
    /// offset, destination resource index); draws agreeing on all three
    /// are validated by a single dispatch.
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    /// Bump-allocation cursor for destination buffer subranges.
    current_dst_entry: Option<CurrentEntry>,
}
907
908impl DrawBatcher {
909 pub(crate) fn new() -> Self {
910 Self {
911 batches: FastHashMap::default(),
912 current_dst_entry: None,
913 }
914 }
915
    /// Registers one indirect draw for validation.
    ///
    /// Reserves space for the validated arguments in a pooled destination
    /// buffer, records the metadata the validation shader needs, and
    /// returns `(dst_resource_index, dst_offset)` — the location the
    /// redirected draw command must read from.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        family: crate::command::DrawCommandFamily,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // On DX12 the shader also writes three u32 "special constants"
        // ahead of the draw arguments (see `write_d3d12_special_constants`
        // in `Draw::new`), so reserve room for them.
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(family);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        // Split the user's byte offset into a bind-group dynamic offset
        // plus a residual offset passed to the shader via the metadata.
        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            family == crate::command::DrawCommandFamily::DrawIndexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        // Append to an existing compatible batch, or start a new one.
        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // Placeholders; assigned in `inject_validation_pass`.
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
981}