use alloc::{sync::Arc, vec::Vec};
use core::{
    cmp::max,
    num::NonZeroU64,
    ops::{Deref, Range},
};

use wgt::{math::align_to, BufferUsages, BufferUses, Features};

use crate::{
    command::encoder::EncodingState,
    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
    resource::InvalidResourceError,
};
use crate::{command::EncoderStateError, device::resource::CommandIndices};
use crate::{
    command::{ArcCommand, ArcReferences, CommandBufferMutable},
    device::queue::TempResource,
    global::Global,
    id::CommandEncoderId,
    init_tracker::MemoryInitKind,
    ray_tracing::{
        ArcBlasAabbGeometry, ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry,
        ArcTlasInstance, ArcTlasPackage, BlasBuildEntry, BlasGeometries,
        BuildAccelerationStructureError, OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
    },
    resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
    scratch::ScratchBuffer,
    snatch::SnatchGuard,
};
use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};

use crate::id::{BlasId, TlasId};

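/// Staged state for a single BLAS build: the BLAS itself, its resolved HAL
/// geometry entries, and the offset of its region within the shared scratch buffer.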
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

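/// Staged state for a single TLAS build, analogous to [`BlasStore`].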
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

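/// An [`UnsafeTlasStore`] together with the byte range its instance data
/// occupies in the shared instance staging vector.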
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}

impl Global {
    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
        self.hub.blas_s.get(blas_id).get()
    }

    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
        self.hub.tlas_s.get(tlas_id).get()
    }

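    /// Marks the given acceleration structures as built without recording an
    /// actual build: it pushes a build action whose TLAS entries carry no BLAS
    /// dependencies, so submission-time dependency validation passes for them.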
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.with_buffer(
            crate::command::EncodingApi::Raw,
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());

                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

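    /// Resolves the id-based build entries into `Arc`-based ones and enqueues
    /// an [`ArcCommand::BuildAccelerationStructures`] on the encoder; the
    /// validation and encoding work happens later in
    /// [`build_acceleration_structures`].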
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
            let blas = blas_iter
                .map(|blas_entry| {
                    let geometries = match blas_entry.geometries {
                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
                            let tri_geo = triangle_geometries
                                .map(|tg| {
                                    Ok(ArcBlasTriangleGeometry {
                                        size: tg.size.clone(),
                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
                                        index_buffer: tg
                                            .index_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        transform_buffer: tg
                                            .transform_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        first_vertex: tg.first_vertex,
                                        vertex_stride: tg.vertex_stride,
                                        first_index: tg.first_index,
                                        transform_buffer_offset: tg.transform_buffer_offset,
                                    })
                                })
                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
                            ArcBlasGeometries::TriangleGeometries(tri_geo)
                        }
                        BlasGeometries::AabbGeometries(aabb_geometries) => {
                            let aabb_geo = aabb_geometries
                                .map(|ag| {
                                    Ok(ArcBlasAabbGeometry {
                                        size: ag.size.clone(),
                                        stride: ag.stride,
                                        aabb_buffer: self.resolve_buffer_id(ag.aabb_buffer)?,
                                        primitive_offset: ag.primitive_offset,
                                    })
                                })
                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
                            ArcBlasGeometries::AabbGeometries(aabb_geo)
                        }
                    };
                    Ok(ArcBlasBuildEntry {
                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
                        geometries,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            let tlas = tlas_iter
                .map(|tlas_package| {
                    let instances = tlas_package
                        .instances
                        .map(|instance| {
                            instance
                                .as_ref()
                                .map(|instance| {
                                    Ok(ArcTlasInstance {
                                        blas: self.resolve_blas_id(instance.blas_id)?,
                                        transform: *instance.transform,
                                        custom_data: instance.custom_data,
                                        mask: instance.mask,
                                    })
                                })
                                .transpose()
                        })
                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
                    Ok(ArcTlasPackage {
                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
                        instances,
                        lowest_unmodified: tlas_package.lowest_unmodified,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
        })
    }
}

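/// Encodes the queued acceleration structure builds. This proceeds in phases:
/// validate and stage every BLAS entry ([`iter_blas`]), stage TLAS instance
/// data and validate each instance, allocate one scratch buffer sized for the
/// larger of the BLAS and TLAS totals, then record the BLAS builds followed by
/// the TLAS builds with the required barriers in between.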
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    let mut instance_buffer_staging_source = Vec::<u8>::new();

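    // For each TLAS: reserve its aligned region of the scratch buffer, convert
    // its instances to the HAL byte layout in `instance_buffer_staging_source`,
    // and validate each instance against the TLAS it belongs to.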
    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        assert!(
            tlas.size_info.build_scratch_size
                <= u64::MAX
                    - u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment)
        );
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size,
            u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
        );

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.iter().flatten() {
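            // The custom index must fit in 24 bits: in the packed instance
            // layout it shares a 32-bit word with the 8-bit visibility mask
            // (as in Vulkan's VkAccelerationStructureInstanceKHR).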
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        return Ok(());
    };

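    // BLAS and TLAS builds run back-to-back on the same scratch buffer, so a
    // single allocation sized for the larger of the two totals suffices.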
    assert!(scratch_size.get() != u64::MAX);

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

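    // A transition from scratch usage back to scratch usage acts as an
    // execution barrier: TLAS builds that reuse this buffer must not overlap
    // the BLAS builds that precede them.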
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

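        // Copy each TLAS's slice of the staged instance data into its instance
        // buffer. Each buffer is transitioned to COPY_DST for the copy; the
        // barrier back to TLAS input is collected and applied once all copies
        // are recorded.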
        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}

impl CommandBufferMutable {
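    /// Validates, at submission time, that every acceleration structure use is
    /// consistent: each built TLAS must have all of its BLAS dependencies
    /// built, and a TLAS used in a bind group must not be older than any BLAS
    /// it depends on. Build commands are stamped with a monotonically
    /// increasing index to establish that ordering.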
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }

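    /// Gathers the raw handle of every BLAS referenced by the TLAS builds and
    /// TLAS uses recorded on this command buffer, and registers them as
    /// acceleration structure dependencies on the HAL command encoder.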
    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
        let tlas_dependencies_locks: Vec<_> = self
            .as_actions
            .iter()
            .filter_map(|action| {
                if let AsAction::UseTlas(tlas) = action {
                    Some(tlas.dependencies.read())
                } else {
                    None
                }
            })
            .collect();
        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
        let mut dependencies = Vec::new();
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    for tlas_build in build.tlas_s_built.iter() {
                        for dependency in &tlas_build.dependencies {
                            if let Some(dependency) = dependency.raw(snatch_guard) {
                                dependencies.push(dependency);
                            }
                        }
                    }
                }
                AsAction::UseTlas(_tlas) => {
                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap();
                    for dependency in tlas_dependencies.iter() {
                        if let Some(dependency) = dependency.raw(snatch_guard) {
                            dependencies.push(dependency);
                        }
                    }
                }
            }
        }
        if !dependencies.is_empty() {
            unsafe {
                self.encoder
                    .raw
                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
            }
        }
    }
}

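/// Walks the BLAS build entries, validating each geometry against the size
/// descriptors the BLAS was created with, recording tracker state, usage
/// barriers, and memory-initialization actions for every input buffer, and
/// accumulating the aligned scratch-buffer requirement. The resulting HAL
/// entries are appended to `blas_storage`.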
fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
    blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
    build_command: &mut AsBuild,
    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'buffers>>,
    state: &mut EncodingState<'snatch_guard, '_>,
) -> Result<(), BuildAccelerationStructureError> {
    for entry in blas_iter {
        let blas = &entry.blas;
        state.tracker.blas_s.insert_single(blas.clone());

        build_command.blas_s_built.push(blas.clone());

        match &entry.geometries {
            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
                let mut triangle_entries =
                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();

                for (i, mesh) in triangle_geometries.iter().enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
                        _ => {
                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
                                blas.error_ident(),
                            ));
                        }
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != mesh.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            mesh.size.flags,
                        ));
                    }

                    if size_desc.vertex_count < mesh.size.vertex_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
                                blas.error_ident(),
                                size_desc.vertex_count,
                                mesh.size.vertex_count,
                            ),
                        );
                    }

                    if size_desc.vertex_format != mesh.size.vertex_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
                            blas.error_ident(),
                            size_desc.vertex_format,
                            mesh.size.vertex_format,
                        ));
                    }

                    if size_desc
                        .vertex_format
                        .min_acceleration_structure_vertex_stride()
                        > mesh.vertex_stride
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .min_acceleration_structure_vertex_stride(),
                            mesh.vertex_stride,
                        ));
                    }

                    if mesh.vertex_stride
                        % size_desc
                            .vertex_format
                            .acceleration_structure_stride_alignment()
                        != 0
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .acceleration_structure_stride_alignment(),
                            mesh.vertex_stride,
                        ));
                    }

                    match (size_desc.index_count, mesh.size.index_count) {
                        (Some(_), None) | (None, Some(_)) => {
                            return Err(
                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
                                    blas.error_ident(),
                                ),
                            )
                        }
                        (Some(create), Some(build)) if create < build => {
                            return Err(
                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
                                    blas.error_ident(),
                                    create,
                                    build,
                                ),
                            )
                        }
                        _ => {}
                    }

                    if size_desc.index_format != mesh.size.index_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
                            blas.error_ident(),
                            size_desc.index_format,
                            mesh.size.index_format,
                        ));
                    }

                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
                            blas.error_ident(),
                        ));
                    }
                    let vertex_buffer = mesh.vertex_buffer.clone();
                    let vertex_pending = state.tracker.buffers.set_single(
                        &vertex_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );
                    let vertex_buffer = {
                        let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
                        let vertex_buffer = &mesh.vertex_buffer;
                        vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

                        if let Some(barrier) = vertex_pending.map(|pending| {
                            pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }
                        if u64::from(mesh.size.vertex_count)
                            .checked_add(u64::from(mesh.first_vertex))
                            .and_then(|end_vertex| end_vertex.checked_mul(mesh.vertex_stride))
                            .is_none_or(|end| vertex_buffer.size < end)
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
                                buffer_ident: vertex_buffer.error_ident(),
                                offset: u64::from(mesh.first_vertex)
                                    .saturating_mul(mesh.vertex_stride),
                                region_size: u64::from(mesh.size.vertex_count)
                                    .saturating_mul(mesh.vertex_stride),
                                buffer_size: vertex_buffer.size,
                            });
                        }
                        let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
                        state.buffer_memory_init_actions.extend(
                            vertex_buffer.initialization_status.read().create_action(
                                vertex_buffer,
                                vertex_buffer_offset
                                    ..(vertex_buffer_offset
                                        + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        vertex_raw
                    };
                    let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
                        if mesh.first_index.is_none() || mesh.size.index_count.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                index_buffer.error_ident(),
                            ));
                        }
                        let index_pending = state.tracker.buffers.set_single(
                            index_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        let index_raw = index_buffer.try_raw(state.snatch_guard)?;
                        index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

                        if let Some(barrier) = index_pending.map(|pending| {
                            pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }
                        let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                        let vertex_offset = mesh.first_index.unwrap().saturating_mul(index_stride);
                        let indexes_size =
                            mesh.size.index_count.unwrap().saturating_mul(index_stride);

                        if mesh.size.index_count.unwrap() % 3 != 0 {
                            return Err(BuildAccelerationStructureError::InvalidIndexCount(
                                index_buffer.error_ident(),
                                mesh.size.index_count.unwrap(),
                            ));
                        }
                        if indexes_size == u32::MAX
                            || vertex_offset == u32::MAX
                            || index_buffer.size < u64::from(vertex_offset)
                            || index_buffer.size - u64::from(vertex_offset)
                                < u64::from(indexes_size)
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
                                buffer_ident: index_buffer.error_ident(),
                                offset: u64::from(vertex_offset),
                                region_size: u64::from(indexes_size),
                                buffer_size: index_buffer.size,
                            });
                        }

                        state.buffer_memory_init_actions.extend(
                            index_buffer.initialization_status.read().create_action(
                                index_buffer,
                                u64::from(vertex_offset)
                                    ..(u64::from(vertex_offset) + u64::from(indexes_size)),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        Some((index_raw, vertex_offset))
                    } else {
                        None
                    };
                    let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
                    {
                        if !blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::UseTransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        if mesh.transform_buffer_offset.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                transform_buffer.error_ident(),
                            ));
                        }
                        let transform_pending = state.tracker.buffers.set_single(
                            transform_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
                        transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

                        if let Some(barrier) = transform_pending.map(|pending| {
                            pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }

                        let offset = u32::try_from(mesh.transform_buffer_offset.unwrap())
                            .unwrap_or(u32::MAX);

                        if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT as u32 != 0 {
                            return Err(
                                BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                                    transform_buffer.error_ident(),
                                ),
                            );
                        }
                        if offset == u32::MAX
                            || transform_buffer.size < 48
                            || transform_buffer.size - 48 < u64::from(offset)
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
                                buffer_ident: transform_buffer.error_ident(),
                                region_size: 48,
                                offset: u64::from(offset),
                                buffer_size: transform_buffer.size,
                            });
                        }
                        state.buffer_memory_init_actions.extend(
                            transform_buffer.initialization_status.read().create_action(
                                transform_buffer,
                                u64::from(offset)..(u64::from(offset) + 48),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        Some((transform_raw, offset))
                    } else {
                        if blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::TransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        None
                    };

                    let triangles = hal::AccelerationStructureTriangles {
                        vertex_buffer: Some(vertex_buffer),
                        vertex_format: mesh.size.vertex_format,
                        first_vertex: mesh.first_vertex,
                        vertex_count: mesh.size.vertex_count,
                        vertex_stride: mesh.vertex_stride,
                        indices: index_buffer.map(|(index_raw, vertex_offset)| {
                            hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                                format: mesh.size.index_format.unwrap(),
                                buffer: Some(index_raw),
                                offset: vertex_offset,
                                count: mesh.size.index_count.unwrap(),
                            }
                        }),
                        transform: transform_buffer.map(|(transform_raw, offset)| {
                            hal::AccelerationStructureTriangleTransform {
                                buffer: transform_raw,
                                offset,
                            }
                        }),
                        flags: mesh.size.flags,
                    };
                    triangle_entries.push(triangles);
                }

                {
                    let scratch_buffer_offset = *scratch_buffer_blas_size;
                    assert!(
                        blas.size_info.build_scratch_size
                            <= u64::MAX
                                - u64::from(
                                    state.device.alignments.ray_tracing_scratch_buffer_alignment
                                )
                    );
                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
                        blas.size_info.build_scratch_size,
                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
                    ));

                    blas_storage.push(BlasStore {
                        blas: blas.clone(),
                        entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                        scratch_buffer_offset,
                    });
                }
            }
            ArcBlasGeometries::AabbGeometries(aabb_geometries) => {
                let mut aabb_entries =
                    Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::new();

                for (i, aabb) in aabb_geometries.iter().enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => descriptors,
                        _ => {
                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
                                blas.error_ident(),
                            ));
                        }
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != aabb.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            aabb.size.flags,
                        ));
                    }

                    if size_desc.primitive_count < aabb.size.primitive_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasAabbPrimitiveCount(
                                blas.error_ident(),
                                size_desc.primitive_count,
                                aabb.size.primitive_count,
                            ),
                        );
                    }

                    if aabb.primitive_offset % 8 != 0 {
                        return Err(
                            BuildAccelerationStructureError::UnalignedAabbPrimitiveOffset(
                                blas.error_ident(),
                            ),
                        );
                    }

                    if aabb.stride < wgt::AABB_GEOMETRY_MIN_STRIDE || aabb.stride % 8 != 0 {
                        return Err(BuildAccelerationStructureError::InvalidAabbStride(
                            blas.error_ident(),
                            aabb.stride,
                        ));
                    }

                    let aabb_buffer = aabb.aabb_buffer.clone();
                    let aabb_pending = state.tracker.buffers.set_single(
                        &aabb_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );
                    let aabb_raw = {
                        let aabb_raw = aabb.aabb_buffer.as_ref().try_raw(state.snatch_guard)?;
                        let aabb_buffer = &aabb.aabb_buffer;
                        aabb_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

                        if let Some(barrier) = aabb_pending.map(|pending| {
                            pending.into_hal(aabb_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }

                        let aabb_size = aabb
                            .size
                            .primitive_count
                            .saturating_mul(u32::try_from(aabb.stride).unwrap());
                        if aabb_size == u32::MAX
                            || aabb_buffer.size < u64::from(aabb.primitive_offset)
                            || aabb_buffer.size - u64::from(aabb.primitive_offset)
                                < u64::from(aabb_size)
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
                                buffer_ident: aabb_buffer.error_ident(),
                                region_size: u64::from(aabb_size),
                                offset: u64::from(aabb.primitive_offset),
                                buffer_size: aabb_buffer.size,
                            });
                        }

                        state.buffer_memory_init_actions.extend(
                            aabb_buffer.initialization_status.read().create_action(
                                aabb_buffer,
                                u64::from(aabb.primitive_offset)
                                    ..u64::from(aabb.primitive_offset) + u64::from(aabb_size),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        aabb_raw
                    };

                    aabb_entries.push(hal::AccelerationStructureAABBs {
                        buffer: Some(aabb_raw),
                        offset: aabb.primitive_offset,
                        count: aabb.size.primitive_count,
                        stride: aabb.stride,
                        flags: aabb.size.flags,
                    });
                }

                {
                    let scratch_buffer_offset = *scratch_buffer_blas_size;
                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
                        blas.size_info.build_scratch_size,
                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
                    ));

                    blas_storage.push(BlasStore {
                        blas: blas.clone(),
                        entries: hal::AccelerationStructureEntries::AABBs(aabb_entries),
                        scratch_buffer_offset,
                    });
                }
            }
        }
    }
    Ok(())
}

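/// Turns a staged [`BlasStore`] into a HAL build descriptor, rejecting BLASes
/// that have already been compacted and queueing compaction-size reads for
/// those built with `ALLOW_COMPACTION`.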
fn map_blas<'a>(
    storage: &'a BlasStore<'_>,
    scratch_buffer: &'a dyn hal::DynBuffer,
    snatch_guard: &'a SnatchGuard,
    blases_compactable: &mut Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) -> Result<
    hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >,
    BuildAccelerationStructureError,
> {
    let BlasStore {
        blas,
        entries,
        scratch_buffer_offset,
    } = storage;
    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
        log::debug!("only rebuild implemented")
    }
    let raw = blas.try_raw(snatch_guard)?;

    let state_lock = blas.compacted_state.lock();
    if let BlasCompactState::Compacted = *state_lock {
        return Err(BuildAccelerationStructureError::CompactedBlas(
            blas.error_ident(),
        ));
    }

    if blas
        .flags
        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
    {
        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
    }
    Ok(hal::BuildAccelerationStructureDescriptor {
        entries,
        mode: hal::AccelerationStructureBuildMode::Build,
        flags: blas.flags,
        source_acceleration_structure: None,
        destination_acceleration_structure: raw,
        scratch_buffer,
        scratch_buffer_offset: *scratch_buffer_offset,
    })
}

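/// Records the BLAS build commands with the surrounding barriers: input
/// buffers are transitioned first, the builds are fenced between BUILD_INPUT
/// and BUILD_OUTPUT, compaction sizes are read back for eligible BLASes, and a
/// final barrier prepares the results for the TLAS stage.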
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}