1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::{
3 cmp::max,
4 num::NonZeroU64,
5 ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
11use crate::{
12 command::CommandBufferMutable,
13 device::queue::TempResource,
14 global::Global,
15 hub::Hub,
16 id::CommandEncoderId,
17 init_tracker::MemoryInitKind,
18 ray_tracing::{
19 BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
20 TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
21 TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
22 },
23 resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
24 scratch::ScratchBuffer,
25 snatch::SnatchGuard,
26 track::PendingTransition,
27};
28use crate::{command::EncoderStateError, device::resource::CommandIndices};
29use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
30
31use crate::id::{BlasId, TlasId};
32
/// Per-geometry staging record produced by `iter_blas` while validating one
/// triangle geometry of a BLAS build, consumed later by `iter_buffers`.
struct TriangleBufferStore<'a> {
    // Resolved vertex buffer for this geometry.
    vertex_buffer: Arc<Buffer>,
    // Pending usage transition for the vertex buffer, if the tracker produced one.
    vertex_transition: Option<PendingTransition<BufferUses>>,
    // Optional index buffer together with its pending transition.
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    // Optional transform buffer together with its pending transition.
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    // The caller-supplied geometry description this record was built from.
    geometry: BlasTriangleGeometry<'a>,
    // Set only on the *last* geometry of a BLAS; `iter_buffers` uses it to
    // detect the end of one BLAS's run of geometries and flush them into a
    // `BlasStore`.
    ending_blas: Option<Arc<Blas>>,
}
41
/// A BLAS ready to be built: the resource, its assembled hal geometry
/// entries, and its offset into the shared scratch buffer.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    // Offset into the single scratch buffer shared by all BLAS builds in
    // this command (aligned by `ray_tracing_scratch_buffer_alignment`).
    scratch_buffer_offset: u64,
}
47
/// A TLAS ready to be built: the resource, its instance entries, and its
/// offset into the shared scratch buffer. "Unsafe" here mirrors the fact
/// that the entries borrow the TLAS's internal instance buffer.
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    // Offset into the scratch buffer shared by all TLAS builds in this command.
    scratch_buffer_offset: u64,
}
53
/// Build record for one TLAS plus the byte range of its serialized instance
/// data within the staging source (`instance_buffer_staging_source`).
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    // Byte range of this TLAS's instances inside the staging upload; used to
    // copy just this TLAS's slice into its instance buffer.
    range: Range<usize>,
}
58
impl Global {
    /// Marks the given BLASes and TLASes as built on this encoder *without*
    /// recording any GPU work.
    ///
    /// Each id is resolved and pushed into an [`AsAction::Build`] record;
    /// TLAS entries are recorded with an empty dependency list. Validation at
    /// submission (`validate_acceleration_structure_actions`) then treats
    /// them as built. Requires `Features::EXPERIMENTAL_RAY_QUERY`.
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.record_with(
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::default();

                // Resolve every id up front so invalid ids fail the whole
                // command instead of partially recording it.
                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        // No dependencies: the caller asserts these are built.
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

    /// Validates and records BLAS and TLAS builds into the encoder.
    ///
    /// High-level flow:
    /// 1. Drain both caller iterators into owned `Trace*` copies (also used
    ///    for the `trace` feature), then re-borrow them as fresh iterators.
    /// 2. Validate BLAS entries (`iter_blas`) and their buffers
    ///    (`iter_buffers`), accumulating input barriers and scratch size.
    /// 3. Validate TLAS packages, serializing instances into a CPU staging
    ///    vec and sizing the TLAS portion of the scratch buffer.
    /// 4. Allocate one scratch buffer sized to the max of the BLAS and TLAS
    ///    requirements (BLAS and TLAS builds are separated by a barrier, so
    ///    the same scratch memory is reused).
    /// 5. Record BLAS builds, then upload instance data and record TLAS
    ///    builds, with the appropriate buffer/AS barriers in between.
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut build_command = AsBuild::default();

        // Take ownership of the BLAS entries; the caller's iterators are
        // single-pass, and the trace copies can be re-iterated below.
        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        // Same for the TLAS packages.
        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

        // Rebuild borrowed-view iterators over the owned trace data; these
        // feed the actual validation below.
        let blas_iter = trace_blas.iter().map(|blas_entry| {
            let geometries = match &blas_entry.geometries {
                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                        size: &tg.size,
                        vertex_buffer: tg.vertex_buffer,
                        index_buffer: tg.index_buffer,
                        transform_buffer: tg.transform_buffer,
                        first_vertex: tg.first_vertex,
                        vertex_stride: tg.vertex_stride,
                        first_index: tg.first_index,
                        transform_buffer_offset: tg.transform_buffer_offset,
                    });
                    BlasGeometries::TriangleGeometries(Box::new(iter))
                }
            };
            BlasBuildEntry {
                blas_id: blas_entry.blas_id,
                geometries,
            }
        });

        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
            let instances = tlas_package.instances.iter().map(|instance| {
                instance.as_ref().map(|instance| TlasInstance {
                    blas_id: instance.blas_id,
                    transform: &instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                })
            });
            TlasPackage {
                tlas_id: tlas_package.tlas_id,
                instances: Box::new(instances),
                lowest_unmodified: tlas_package.lowest_unmodified,
            }
        });

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.record_with(|cmd_buf_data| {
            #[cfg(feature = "trace")]
            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::BuildAccelerationStructures {
                    blas: trace_blas.clone(),
                    tlas: trace_tlas.clone(),
                });
            }

            let device = &cmd_enc.device;
            device.check_is_valid()?;
            device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

            // Phase 1: validate BLAS entries against their creation-time
            // size descriptors and collect per-geometry buffer records.
            let mut buf_storage = Vec::new();
            iter_blas(
                blas_iter,
                cmd_buf_data,
                &mut build_command,
                &mut buf_storage,
                hub,
            )?;

            // Phase 2: validate the buffers themselves (usage, size, init
            // tracking), gathering input barriers, the BLAS scratch size,
            // and per-BLAS hal entries.
            let snatch_guard = device.snatchable_lock.read();
            let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
            let mut scratch_buffer_blas_size = 0;
            let mut blas_storage = Vec::new();
            iter_buffers(
                &mut buf_storage,
                &snatch_guard,
                &mut input_barriers,
                cmd_buf_data,
                &mut scratch_buffer_blas_size,
                &mut blas_storage,
                hub,
                device.alignments.ray_tracing_scratch_buffer_alignment,
            )?;
            let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

            for package in tlas_iter {
                let tlas = hub.tlas_s.get(package.tlas_id).get()?;

                cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

                tlas_lock_store.push((Some(package), tlas))
            }

            // Phase 3: validate TLAS packages and serialize their instances
            // into one contiguous CPU-side staging vec.
            let mut scratch_buffer_tlas_size = 0;
            let mut tlas_storage = Vec::<TlasStore>::new();
            let mut instance_buffer_staging_source = Vec::<u8>::new();

            for (package, tlas) in &mut tlas_lock_store {
                let package = package.take().unwrap();

                let scratch_buffer_offset = scratch_buffer_tlas_size;
                scratch_buffer_tlas_size += align_to(
                    tlas.size_info.build_scratch_size as u32,
                    device.alignments.ray_tracing_scratch_buffer_alignment,
                ) as u64;

                let first_byte_index = instance_buffer_staging_source.len();

                let mut dependencies = Vec::new();

                let mut instance_count = 0;
                for instance in package.instances.flatten() {
                    // The instance custom index is limited to 24 bits
                    // (hence the `1 << 24` bound).
                    if instance.custom_data >= (1u32 << 24u32) {
                        return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                            tlas.error_ident(),
                        ));
                    }
                    let blas = hub.blas_s.get(instance.blas_id).get()?;

                    cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

                    // Backend-specific serialization of one instance record.
                    instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                        hal::TlasInstance {
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                            blas_address: blas.handle,
                        },
                    ));

                    // A vertex-return TLAS may only reference vertex-return
                    // BLASes.
                    if tlas.flags.contains(
                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                    ) && !blas.flags.contains(
                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                    ) {
                        return Err(
                            BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ),
                        );
                    }

                    instance_count += 1;

                    dependencies.push(blas.clone());
                }

                build_command.tlas_s_built.push(TlasBuild {
                    tlas: tlas.clone(),
                    dependencies,
                });

                if instance_count > tlas.max_instance_count {
                    return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                        tlas.error_ident(),
                        instance_count,
                        tlas.max_instance_count,
                    ));
                }

                tlas_storage.push(TlasStore {
                    internal: UnsafeTlasStore {
                        tlas: tlas.clone(),
                        entries: hal::AccelerationStructureEntries::Instances(
                            hal::AccelerationStructureInstances {
                                buffer: Some(tlas.instance_buffer.as_ref()),
                                offset: 0,
                                count: instance_count,
                            },
                        ),
                        scratch_buffer_offset,
                    },
                    range: first_byte_index..instance_buffer_staging_source.len(),
                });
            }

            // Nothing to build at all (both scratch sizes are zero) — the
            // command degenerates to a no-op.
            let Some(scratch_size) =
                wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
            else {
                return Ok(());
            };

            // One scratch buffer serves both phases; a scratch barrier
            // between them (see `build_blas`) makes the reuse safe.
            let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

            let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: scratch_buffer.raw(),
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                    to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                },
            };

            let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

            for &TlasStore {
                internal:
                    UnsafeTlasStore {
                        ref tlas,
                        ref entries,
                        ref scratch_buffer_offset,
                    },
                ..
            } in &tlas_storage
            {
                if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                    log::info!("only rebuild implemented")
                }
                tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                    entries,
                    mode: hal::AccelerationStructureBuildMode::Build,
                    flags: tlas.flags,
                    source_acceleration_structure: None,
                    destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                    scratch_buffer: scratch_buffer.raw(),
                    scratch_buffer_offset: *scratch_buffer_offset,
                })
            }

            let blas_present = !blas_storage.is_empty();
            let tlas_present = !tlas_storage.is_empty();

            let cmd_buf_raw = cmd_buf_data.encoder.open()?;

            let mut blas_s_compactable = Vec::new();
            let mut descriptors = Vec::new();

            for storage in &blas_storage {
                descriptors.push(map_blas(
                    storage,
                    scratch_buffer.raw(),
                    &snatch_guard,
                    &mut blas_s_compactable,
                )?);
            }

            // Phase 4: record all BLAS builds (plus compaction-size reads
            // and the barriers around them).
            build_blas(
                cmd_buf_raw,
                blas_present,
                tlas_present,
                input_barriers,
                &descriptors,
                scratch_buffer_barrier,
                blas_s_compactable,
            );

            // Phase 5: upload instance data and record TLAS builds.
            if tlas_present {
                let staging_buffer = if !instance_buffer_staging_source.is_empty() {
                    let mut staging_buffer = StagingBuffer::new(
                        device,
                        wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
                    )?;
                    staging_buffer.write(&instance_buffer_staging_source);
                    let flushed = staging_buffer.flush();
                    Some(flushed)
                } else {
                    None
                };

                unsafe {
                    if let Some(ref staging_buffer) = staging_buffer {
                        cmd_buf_raw.transition_buffers(&[
                            hal::BufferBarrier::<dyn hal::DynBuffer> {
                                buffer: staging_buffer.raw(),
                                usage: hal::StateTransition {
                                    from: BufferUses::MAP_WRITE,
                                    to: BufferUses::COPY_SRC,
                                },
                            },
                        ]);
                    }
                }

                // For each TLAS: transition its instance buffer to COPY_DST,
                // copy its slice of the staging data in, and queue the
                // reverse (COPY_DST -> TLAS input) barrier to be emitted in
                // one batch before the builds.
                let mut instance_buffer_barriers = Vec::new();
                for &TlasStore {
                    internal: UnsafeTlasStore { ref tlas, .. },
                    ref range,
                } in &tlas_storage
                {
                    // A zero-length range means this TLAS has no instances;
                    // skip the copy entirely.
                    let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                        None => continue,
                        Some(size) => size,
                    };
                    instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                        buffer: tlas.instance_buffer.as_ref(),
                        usage: hal::StateTransition {
                            from: BufferUses::COPY_DST,
                            to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        },
                    });
                    unsafe {
                        cmd_buf_raw.transition_buffers(&[
                            hal::BufferBarrier::<dyn hal::DynBuffer> {
                                buffer: tlas.instance_buffer.as_ref(),
                                usage: hal::StateTransition {
                                    from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                                    to: BufferUses::COPY_DST,
                                },
                            },
                        ]);
                        let temp = hal::BufferCopy {
                            src_offset: range.start as u64,
                            dst_offset: 0,
                            size,
                        };
                        cmd_buf_raw.copy_buffer_to_buffer(
                            staging_buffer.as_ref().unwrap().raw(),
                            tlas.instance_buffer.as_ref(),
                            &[temp],
                        );
                    }
                }

                unsafe {
                    cmd_buf_raw.transition_buffers(&instance_buffer_barriers);

                    cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

                    // Built TLASes become readable by shaders.
                    cmd_buf_raw.place_acceleration_structure_barrier(
                        hal::AccelerationStructureBarrier {
                            usage: hal::StateTransition {
                                from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                                to: hal::AccelerationStructureUses::SHADER_INPUT,
                            },
                        },
                    );
                }

                // Keep the staging buffer alive until the queue is done with it.
                if let Some(staging_buffer) = staging_buffer {
                    cmd_buf_data
                        .temp_resources
                        .push(TempResource::StagingBuffer(staging_buffer));
                }
            }

            // Same for the scratch buffer.
            cmd_buf_data
                .temp_resources
                .push(TempResource::ScratchBuffer(scratch_buffer));

            cmd_buf_data.as_actions.push(AsAction::Build(build_command));

            Ok(())
        })
    }
}
502
impl CommandBufferMutable {
    /// Submission-time validation of recorded acceleration-structure actions.
    ///
    /// For each `Build` action: assigns a fresh, monotonically increasing
    /// build index to everything built, resets compaction state on built
    /// BLASes, and checks every TLAS dependency was built. For each `UseTlas`
    /// action: checks the TLAS and all its BLAS dependencies are built, that
    /// no BLAS was rebuilt *after* the TLAS (which would make the TLAS
    /// stale), and that each BLAS is still alive (`try_raw`).
    ///
    /// `command_index_guard` must be held for the whole pass so build indices
    /// stay globally ordered across concurrent submissions.
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    // The counter starts non-zero, so this `unwrap` cannot fail.
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        // Rebuilding invalidates any previous compaction
                        // preparation; `map_blas` already rejected builds of
                        // fully compacted BLASes.
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        // Every BLAS referenced by this TLAS build must have
                        // been built at some point.
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        // Record the dependency set for later `UseTlas` checks.
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        // A BLAS rebuilt after the TLAS makes the TLAS stale.
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        // The BLAS must not have been destroyed.
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}
578
579fn iter_blas<'a>(
581 blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
582 cmd_buf_data: &mut CommandBufferMutable,
583 build_command: &mut AsBuild,
584 buf_storage: &mut Vec<TriangleBufferStore<'a>>,
585 hub: &Hub,
586) -> Result<(), BuildAccelerationStructureError> {
587 let mut temp_buffer = Vec::new();
588 for entry in blas_iter {
589 let blas = hub.blas_s.get(entry.blas_id).get()?;
590 cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
591
592 build_command.blas_s_built.push(blas.clone());
593
594 match entry.geometries {
595 BlasGeometries::TriangleGeometries(triangle_geometries) => {
596 for (i, mesh) in triangle_geometries.enumerate() {
597 let size_desc = match &blas.sizes {
598 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
599 };
600 if i >= size_desc.len() {
601 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
602 blas.error_ident(),
603 ));
604 }
605 let size_desc = &size_desc[i];
606
607 if size_desc.flags != mesh.size.flags {
608 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
609 blas.error_ident(),
610 size_desc.flags,
611 mesh.size.flags,
612 ));
613 }
614
615 if size_desc.vertex_count < mesh.size.vertex_count {
616 return Err(
617 BuildAccelerationStructureError::IncompatibleBlasVertexCount(
618 blas.error_ident(),
619 size_desc.vertex_count,
620 mesh.size.vertex_count,
621 ),
622 );
623 }
624
625 if size_desc.vertex_format != mesh.size.vertex_format {
626 return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
627 blas.error_ident(),
628 size_desc.vertex_format,
629 mesh.size.vertex_format,
630 ));
631 }
632
633 if size_desc
634 .vertex_format
635 .min_acceleration_structure_vertex_stride()
636 > mesh.vertex_stride
637 {
638 return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
639 blas.error_ident(),
640 size_desc
641 .vertex_format
642 .min_acceleration_structure_vertex_stride(),
643 mesh.vertex_stride,
644 ));
645 }
646
647 if mesh.vertex_stride
648 % size_desc
649 .vertex_format
650 .acceleration_structure_stride_alignment()
651 != 0
652 {
653 return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
654 blas.error_ident(),
655 size_desc
656 .vertex_format
657 .acceleration_structure_stride_alignment(),
658 mesh.vertex_stride,
659 ));
660 }
661
662 match (size_desc.index_count, mesh.size.index_count) {
663 (Some(_), None) | (None, Some(_)) => {
664 return Err(
665 BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
666 blas.error_ident(),
667 ),
668 )
669 }
670 (Some(create), Some(build)) if create < build => {
671 return Err(
672 BuildAccelerationStructureError::IncompatibleBlasIndexCount(
673 blas.error_ident(),
674 create,
675 build,
676 ),
677 )
678 }
679 _ => {}
680 }
681
682 if size_desc.index_format != mesh.size.index_format {
683 return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
684 blas.error_ident(),
685 size_desc.index_format,
686 mesh.size.index_format,
687 ));
688 }
689
690 if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
691 return Err(BuildAccelerationStructureError::MissingIndexBuffer(
692 blas.error_ident(),
693 ));
694 }
695 let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
696 let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
697 &vertex_buffer,
698 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
699 );
700 let index_data = if let Some(index_id) = mesh.index_buffer {
701 let index_buffer = hub.buffers.get(index_id).get()?;
702 if mesh.first_index.is_none()
703 || mesh.size.index_count.is_none()
704 || mesh.size.index_count.is_none()
705 {
706 return Err(BuildAccelerationStructureError::MissingAssociatedData(
707 index_buffer.error_ident(),
708 ));
709 }
710 let data = cmd_buf_data.trackers.buffers.set_single(
711 &index_buffer,
712 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
713 );
714 Some((index_buffer, data))
715 } else {
716 None
717 };
718 let transform_data = if let Some(transform_id) = mesh.transform_buffer {
719 if !blas
720 .flags
721 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
722 {
723 return Err(BuildAccelerationStructureError::UseTransformMissing(
724 blas.error_ident(),
725 ));
726 }
727 let transform_buffer = hub.buffers.get(transform_id).get()?;
728 if mesh.transform_buffer_offset.is_none() {
729 return Err(BuildAccelerationStructureError::MissingAssociatedData(
730 transform_buffer.error_ident(),
731 ));
732 }
733 let data = cmd_buf_data.trackers.buffers.set_single(
734 &transform_buffer,
735 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
736 );
737 Some((transform_buffer, data))
738 } else {
739 if blas
740 .flags
741 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
742 {
743 return Err(BuildAccelerationStructureError::TransformMissing(
744 blas.error_ident(),
745 ));
746 }
747 None
748 };
749 temp_buffer.push(TriangleBufferStore {
750 vertex_buffer,
751 vertex_transition: vertex_pending,
752 index_buffer_transition: index_data,
753 transform_buffer_transition: transform_data,
754 geometry: mesh,
755 ending_blas: None,
756 });
757 }
758
759 if let Some(last) = temp_buffer.last_mut() {
760 last.ending_blas = Some(blas);
761 buf_storage.append(&mut temp_buffer);
762 }
763 }
764 }
765 }
766 Ok(())
767}
768
/// Second validation pass over the records produced by `iter_blas`.
///
/// For each geometry: checks buffer usage (`BLAS_INPUT`), buffer sizes
/// against the ranges the build will read, emits pending usage transitions
/// into `input_barriers`, registers memory-initialization actions, and
/// assembles the hal triangle entry. When a record carries `ending_blas`,
/// the accumulated triangle entries are flushed into `blas_storage` as one
/// `BlasStore`, and `scratch_buffer_blas_size` is grown by that BLAS's
/// aligned scratch requirement.
///
/// The `unwrap`s on `first_index` / `index_count` / `index_format` /
/// `transform_buffer_offset` rely on `iter_blas` having validated presence.
fn iter_buffers<'a, 'b>(
    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
    snatch_guard: &'a SnatchGuard,
    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
    cmd_buf_data: &mut CommandBufferMutable,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'a>>,
    hub: &Hub,
    ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> {
    // Entries for the current BLAS; flushed whenever `ending_blas` is hit.
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_buffer = buf.vertex_buffer.as_ref();
            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            // The build reads vertices [0, first_vertex + vertex_count).
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            // NOTE(review): re-resolves the id through the hub; presumably
            // the same buffer as `buf.vertex_buffer` — confirm.
            cmd_buf_data.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    &hub.buffers.get(mesh.vertex_buffer).get()?,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            // Triangle lists need a multiple of 3 indices.
            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            cmd_buf_data.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            // Defensive re-check; `iter_blas` already validated presence.
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            // 48 bytes — presumably a 3x4 f32 transform matrix; the buffer
            // must hold one full matrix at `offset`.
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            cmd_buf_data.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        // Assemble the hal geometry entry for this mesh.
        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    // Byte offset; u32 arithmetic here.
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        // End of this BLAS's geometry run: flush the accumulated entries and
        // reserve its aligned slice of the shared scratch buffer.
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}
945
946fn map_blas<'a>(
947 storage: &'a BlasStore<'_>,
948 scratch_buffer: &'a dyn hal::DynBuffer,
949 snatch_guard: &'a SnatchGuard,
950 blases_compactable: &mut Vec<(
951 &'a dyn hal::DynBuffer,
952 &'a dyn hal::DynAccelerationStructure,
953 )>,
954) -> Result<
955 hal::BuildAccelerationStructureDescriptor<
956 'a,
957 dyn hal::DynBuffer,
958 dyn hal::DynAccelerationStructure,
959 >,
960 BuildAccelerationStructureError,
961> {
962 let BlasStore {
963 blas,
964 entries,
965 scratch_buffer_offset,
966 } = storage;
967 if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
968 log::info!("only rebuild implemented")
969 }
970 let raw = blas.try_raw(snatch_guard)?;
971
972 let state_lock = blas.compacted_state.lock();
973 if let BlasCompactState::Compacted = *state_lock {
974 return Err(BuildAccelerationStructureError::CompactedBlas(
975 blas.error_ident(),
976 ));
977 }
978
979 if blas
980 .flags
981 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
982 {
983 blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
984 }
985 Ok(hal::BuildAccelerationStructureDescriptor {
986 entries,
987 mode: hal::AccelerationStructureBuildMode::Build,
988 flags: blas.flags,
989 source_acceleration_structure: None,
990 destination_acceleration_structure: raw,
991 scratch_buffer,
992 scratch_buffer_offset: *scratch_buffer_offset,
993 })
994}
995
996fn build_blas<'a>(
997 cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
998 blas_present: bool,
999 tlas_present: bool,
1000 input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1001 blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1002 'a,
1003 dyn hal::DynBuffer,
1004 dyn hal::DynAccelerationStructure,
1005 >],
1006 scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1007 blas_s_for_compaction: Vec<(
1008 &'a dyn hal::DynBuffer,
1009 &'a dyn hal::DynAccelerationStructure,
1010 )>,
1011) {
1012 unsafe {
1013 cmd_buf_raw.transition_buffers(&input_barriers);
1014 }
1015
1016 if blas_present {
1017 unsafe {
1018 cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1019 usage: hal::StateTransition {
1020 from: hal::AccelerationStructureUses::BUILD_INPUT,
1021 to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1022 },
1023 });
1024
1025 cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1026 }
1027 }
1028
1029 if blas_present && tlas_present {
1030 unsafe {
1031 cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1032 }
1033 }
1034
1035 let mut source_usage = hal::AccelerationStructureUses::empty();
1036 let mut destination_usage = hal::AccelerationStructureUses::empty();
1037 for &(buf, blas) in blas_s_for_compaction.iter() {
1038 unsafe {
1039 cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
1040 buffer: buf,
1041 usage: hal::StateTransition {
1042 from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1043 to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1044 },
1045 }])
1046 }
1047 unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
1048 destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
1049 }
1050
1051 if blas_present {
1052 source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1053 destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1054 }
1055 if tlas_present {
1056 source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1057 destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1058 }
1059 unsafe {
1060 cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1061 usage: hal::StateTransition {
1062 from: source_usage,
1063 to: destination_usage,
1064 },
1065 });
1066 }
1067}