1use alloc::{sync::Arc, vec::Vec};
2use core::{
3 cmp::max,
4 num::NonZeroU64,
5 ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11 command::encoder::EncodingState,
12 ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
13 resource::InvalidResourceError,
14};
15use crate::{command::EncoderStateError, device::resource::CommandIndices};
16use crate::{
17 command::{ArcCommand, ArcReferences, CommandBufferMutable},
18 device::queue::TempResource,
19 global::Global,
20 id::CommandEncoderId,
21 init_tracker::MemoryInitKind,
22 ray_tracing::{
23 ArcBlasAabbGeometry, ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry,
24 ArcTlasInstance, ArcTlasPackage, BlasBuildEntry, BlasGeometries,
25 BuildAccelerationStructureError, OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
26 },
27 resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
28 scratch::ScratchBuffer,
29 snatch::SnatchGuard,
30};
31use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
32
33use crate::id::{BlasId, TlasId};
34
/// A BLAS whose inputs have been validated and lowered to HAL geometry
/// entries, ready to be turned into a HAL build descriptor by `map_blas`.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    // HAL geometry entries borrowed from the caller's input buffers.
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    // Byte offset of this BLAS's region inside the shared scratch buffer.
    scratch_buffer_offset: u64,
}
40
/// A TLAS prepared for building, without the staging-buffer byte range
/// (see [`TlasStore`], which pairs this with that range).
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    // HAL instance entries pointing at the TLAS's own instance buffer.
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    // Byte offset of this TLAS's region inside the shared scratch buffer.
    scratch_buffer_offset: u64,
}
46
/// A TLAS prepared for building, plus the byte range of its serialized
/// instances inside the shared staging source buffer.
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    // `start..end` into `instance_buffer_staging_source` for this TLAS's
    // serialized `hal::TlasInstance` bytes.
    range: Range<usize>,
}
51
52impl Global {
53 fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
54 self.hub.blas_s.get(blas_id).get()
55 }
56
57 fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
58 self.hub.tlas_s.get(tlas_id).get()
59 }
60
61 pub fn command_encoder_mark_acceleration_structures_built(
62 &self,
63 command_encoder_id: CommandEncoderId,
64 blas_ids: &[BlasId],
65 tlas_ids: &[TlasId],
66 ) -> Result<(), EncoderStateError> {
67 profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
68
69 let hub = &self.hub;
70
71 let cmd_enc = hub.command_encoders.get(command_encoder_id);
72
73 let mut cmd_buf_data = cmd_enc.data.lock();
74 cmd_buf_data.with_buffer(
75 crate::command::EncodingApi::Raw,
76 |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
77 let device = &cmd_enc.device;
78 device.check_is_valid()?;
79 device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
80
81 let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());
82
83 for blas in blas_ids {
84 let blas = hub.blas_s.get(*blas).get()?;
85 build_command.blas_s_built.push(blas);
86 }
87
88 for tlas in tlas_ids {
89 let tlas = hub.tlas_s.get(*tlas).get()?;
90 build_command.tlas_s_built.push(TlasBuild {
91 tlas,
92 dependencies: Vec::new(),
93 });
94 }
95
96 cmd_buf_data.as_actions.push(AsAction::Build(build_command));
97 Ok(())
98 },
99 )
100 }
101
102 pub fn command_encoder_build_acceleration_structures<'a>(
103 &self,
104 command_encoder_id: CommandEncoderId,
105 blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
106 tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
107 ) -> Result<(), EncoderStateError> {
108 profiling::scope!("CommandEncoder::build_acceleration_structures");
109
110 let hub = &self.hub;
111
112 let cmd_enc = hub.command_encoders.get(command_encoder_id);
113 let mut cmd_buf_data = cmd_enc.data.lock();
114
115 cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
116 let blas = blas_iter
117 .map(|blas_entry| {
118 let geometries = match blas_entry.geometries {
119 BlasGeometries::TriangleGeometries(triangle_geometries) => {
120 let tri_geo = triangle_geometries
121 .map(|tg| {
122 Ok(ArcBlasTriangleGeometry {
123 size: tg.size.clone(),
124 vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
125 index_buffer: tg
126 .index_buffer
127 .map(|id| self.resolve_buffer_id(id))
128 .transpose()?,
129 transform_buffer: tg
130 .transform_buffer
131 .map(|id| self.resolve_buffer_id(id))
132 .transpose()?,
133 first_vertex: tg.first_vertex,
134 vertex_stride: tg.vertex_stride,
135 first_index: tg.first_index,
136 transform_buffer_offset: tg.transform_buffer_offset,
137 })
138 })
139 .collect::<Result<_, BuildAccelerationStructureError>>()?;
140 ArcBlasGeometries::TriangleGeometries(tri_geo)
141 }
142 BlasGeometries::AabbGeometries(aabb_geometries) => {
143 let aabb_geo = aabb_geometries
144 .map(|ag| {
145 Ok(ArcBlasAabbGeometry {
146 size: ag.size.clone(),
147 stride: ag.stride,
148 aabb_buffer: self.resolve_buffer_id(ag.aabb_buffer)?,
149 primitive_offset: ag.primitive_offset,
150 })
151 })
152 .collect::<Result<_, BuildAccelerationStructureError>>()?;
153 ArcBlasGeometries::AabbGeometries(aabb_geo)
154 }
155 };
156 Ok(ArcBlasBuildEntry {
157 blas: self.resolve_blas_id(blas_entry.blas_id)?,
158 geometries,
159 })
160 })
161 .collect::<Result<_, BuildAccelerationStructureError>>()?;
162
163 let tlas = tlas_iter
164 .map(|tlas_package| {
165 let instances = tlas_package
166 .instances
167 .map(|instance| {
168 instance
169 .as_ref()
170 .map(|instance| {
171 Ok(ArcTlasInstance {
172 blas: self.resolve_blas_id(instance.blas_id)?,
173 transform: *instance.transform,
174 custom_data: instance.custom_data,
175 mask: instance.mask,
176 })
177 })
178 .transpose()
179 })
180 .collect::<Result<_, BuildAccelerationStructureError>>()?;
181 Ok(ArcTlasPackage {
182 tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
183 instances,
184 lowest_unmodified: tlas_package.lowest_unmodified,
185 })
186 })
187 .collect::<Result<_, BuildAccelerationStructureError>>()?;
188
189 Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
190 })
191 }
192}
193
/// Encodes the HAL commands for a queued `BuildAccelerationStructures`
/// command: validates all BLAS/TLAS inputs, allocates a shared scratch
/// buffer, uploads TLAS instance data via a staging buffer, and records
/// the builds plus the barriers between them.
///
/// Temporary resources (scratch and staging buffers) are parked in
/// `state.temp_resources` so they outlive submission.
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    // Phase 1: validate all BLAS entries and lower them to HAL geometry,
    // accumulating the BLAS portion of the scratch-buffer size.
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size: u64 = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    // CPU-side serialization of every TLAS's instances, copied to the GPU
    // via one staging buffer; each TLAS remembers its byte range into it.
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    // Phase 2: validate TLAS packages and serialize their instances.
    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size = scratch_buffer_tlas_size.saturating_add(align_to(
            tlas.size_info.build_scratch_size,
            u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
        ));

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        // `None` slots in the package are skipped by `flatten`.
        for instance in package.instances.iter().flatten() {
            // Custom index is limited to 24 bits by the instance format.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            // Let the backend serialize the instance into its native layout.
            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // Vertex-return on the TLAS requires it on every referenced BLAS.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // BLAS and TLAS builds run back-to-back, so one scratch buffer sized to
    // the larger of the two phases suffices. Zero total size means there is
    // nothing to build at all.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        return Ok(());
    };

    // u64::MAX here means the saturating additions above overflowed.
    if scratch_size.get() == u64::MAX {
        return Err(crate::device::DeviceError::OutOfMemory.into());
    }

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Scratch-to-scratch barrier between the BLAS and TLAS build passes.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    // Phase 3: record the BLAS builds (plus input barriers and compaction
    // size queries) into the encoder.
    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    // Phase 4: upload instance data and record the TLAS builds.
    if tlas_present {
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // A TLAS with zero instances has no bytes to copy.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            // Barrier back to build-input state, applied in one batch after
            // all the copies below.
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Transition the instance buffer to COPY_DST, then copy this
                // TLAS's slice of the staging data into it.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
467
impl CommandBufferMutable {
    /// At submission time, replays the recorded acceleration-structure
    /// actions: assigns build indices to built BLAS/TLAS resources and
    /// checks that every used TLAS was built from currently-built BLASes
    /// that are not newer than the TLAS itself.
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    // Each Build action gets a fresh monotonically increasing
                    // index; the counter starts above zero, so `unwrap` holds.
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        // Rebuilding resets any in-flight compaction state;
                        // building a compacted BLAS is rejected earlier.
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        // Remember which BLASes this TLAS now depends on, for
                        // later UseTlas validation.
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            // NOTE(review): the argument order here is
                            // (tlas, blas), the opposite of the Build arm
                            // above — confirm against the variant's field
                            // meaning.
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        // A BLAS rebuilt after the TLAS would leave the TLAS
                        // referencing stale data.
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }

    /// Collects the raw BLASes every recorded action depends on and hands
    /// them to the backend encoder (used by backends that track
    /// acceleration-structure residency explicitly).
    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
        // Take all UseTlas dependency read-locks up front so the borrowed
        // `Arc<Blas>` references stay valid for the whole collection loop.
        let tlas_dependencies_locks: Vec<_> = self
            .as_actions
            .iter()
            .filter_map(|action| {
                if let AsAction::UseTlas(tlas) = action {
                    Some(tlas.dependencies.read())
                } else {
                    None
                }
            })
            .collect();
        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
        let mut dependencies = Vec::new();
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    for tlas_build in build.tlas_s_built.iter() {
                        for dependency in &tlas_build.dependencies {
                            // Destroyed BLASes are simply skipped.
                            if let Some(dependency) = dependency.raw(snatch_guard) {
                                dependencies.push(dependency);
                            }
                        }
                    }
                }
                AsAction::UseTlas(_tlas) => {
                    // The locks vector was built from the same action list in
                    // the same order, so the iterator yields this action's
                    // guard; `unwrap` holds.
                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap();
                    for dependency in tlas_dependencies.iter() {
                        if let Some(dependency) = dependency.raw(snatch_guard) {
                            dependencies.push(dependency);
                        }
                    }
                }
            }
        }
        if !dependencies.is_empty() {
            unsafe {
                self.encoder
                    .raw
                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
            }
        }
    }
}
588
589fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
591 blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
592 build_command: &mut AsBuild,
593 input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
594 scratch_buffer_blas_size: &mut u64,
595 blas_storage: &mut Vec<BlasStore<'buffers>>,
596 state: &mut EncodingState<'snatch_guard, '_>,
597) -> Result<(), BuildAccelerationStructureError> {
598 for entry in blas_iter {
599 let blas = &entry.blas;
600 state.tracker.blas_s.insert_single(blas.clone());
601
602 build_command.blas_s_built.push(blas.clone());
603
604 match &entry.geometries {
605 ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
606 let mut triangle_entries =
607 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
608
609 for (i, mesh) in triangle_geometries.iter().enumerate() {
610 let size_desc = match &blas.sizes {
611 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
612 _ => {
613 return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
614 blas.error_ident(),
615 ));
616 }
617 };
618 if i >= size_desc.len() {
619 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
620 blas.error_ident(),
621 ));
622 }
623 let size_desc = &size_desc[i];
624
625 if size_desc.flags != mesh.size.flags {
626 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
627 blas.error_ident(),
628 size_desc.flags,
629 mesh.size.flags,
630 ));
631 }
632
633 if size_desc.vertex_count < mesh.size.vertex_count {
634 return Err(
635 BuildAccelerationStructureError::IncompatibleBlasVertexCount(
636 blas.error_ident(),
637 size_desc.vertex_count,
638 mesh.size.vertex_count,
639 ),
640 );
641 }
642
643 if size_desc.vertex_format != mesh.size.vertex_format {
644 return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
645 blas.error_ident(),
646 size_desc.vertex_format,
647 mesh.size.vertex_format,
648 ));
649 }
650
651 if size_desc
652 .vertex_format
653 .min_acceleration_structure_vertex_stride()
654 > mesh.vertex_stride
655 {
656 return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
657 blas.error_ident(),
658 size_desc
659 .vertex_format
660 .min_acceleration_structure_vertex_stride(),
661 mesh.vertex_stride,
662 ));
663 }
664
665 if mesh.vertex_stride
666 % size_desc
667 .vertex_format
668 .acceleration_structure_stride_alignment()
669 != 0
670 {
671 return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
672 blas.error_ident(),
673 size_desc
674 .vertex_format
675 .acceleration_structure_stride_alignment(),
676 mesh.vertex_stride,
677 ));
678 }
679
680 match (size_desc.index_count, mesh.size.index_count) {
681 (Some(_), None) | (None, Some(_)) => {
682 return Err(
683 BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
684 blas.error_ident(),
685 ),
686 )
687 }
688 (Some(create), Some(build)) if create < build => {
689 return Err(
690 BuildAccelerationStructureError::IncompatibleBlasIndexCount(
691 blas.error_ident(),
692 create,
693 build,
694 ),
695 )
696 }
697 _ => {}
698 }
699
700 if size_desc.index_format != mesh.size.index_format {
701 return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
702 blas.error_ident(),
703 size_desc.index_format,
704 mesh.size.index_format,
705 ));
706 }
707
708 if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
709 return Err(BuildAccelerationStructureError::MissingIndexBuffer(
710 blas.error_ident(),
711 ));
712 }
713 let vertex_buffer = mesh.vertex_buffer.clone();
714 let vertex_pending = state.tracker.buffers.set_single(
715 &vertex_buffer,
716 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
717 );
718 let vertex_buffer = {
719 let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
720 let vertex_buffer = &mesh.vertex_buffer;
721 vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
722
723 if let Some(barrier) = vertex_pending.map(|pending| {
724 pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
725 }) {
726 input_barriers.push(barrier);
727 }
728 if u64::from(mesh.size.vertex_count)
729 .checked_add(u64::from(mesh.first_vertex))
730 .and_then(|end_vertex| end_vertex.checked_mul(mesh.vertex_stride))
731 .is_none_or(|end| vertex_buffer.size < end)
732 {
733 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
734 buffer_ident: vertex_buffer.error_ident(),
735 offset: u64::from(mesh.first_vertex)
736 .saturating_mul(mesh.vertex_stride),
737 region_size: u64::from(mesh.size.vertex_count)
738 .saturating_mul(mesh.vertex_stride),
739 buffer_size: vertex_buffer.size,
740 });
741 }
742 let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
743 state.buffer_memory_init_actions.extend(
744 vertex_buffer.initialization_status.read().create_action(
745 vertex_buffer,
746 vertex_buffer_offset
747 ..(vertex_buffer_offset
748 + mesh.size.vertex_count as u64 * mesh.vertex_stride),
749 MemoryInitKind::NeedsInitializedMemory,
750 ),
751 );
752 vertex_raw
753 };
754 let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
755 if mesh.first_index.is_none()
756 || mesh.size.index_count.is_none()
757 || mesh.size.index_count.is_none()
758 {
759 return Err(BuildAccelerationStructureError::MissingAssociatedData(
760 index_buffer.error_ident(),
761 ));
762 }
763 let index_pending = state.tracker.buffers.set_single(
764 index_buffer,
765 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
766 );
767 let index_raw = index_buffer.try_raw(state.snatch_guard)?;
768 index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
769
770 if let Some(barrier) = index_pending.map(|pending| {
771 pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
772 }) {
773 input_barriers.push(barrier);
774 }
775 let index_stride = mesh.size.index_format.unwrap().byte_size();
776 let Some(vertex_offset) =
778 mesh.first_index.unwrap().checked_mul(index_stride)
779 else {
780 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
781 buffer_ident: index_buffer.error_ident(),
782 offset: u64::from(mesh.first_index.unwrap())
783 .saturating_mul(u64::from(index_stride)),
784 count: u64::from(mesh.first_index.unwrap()),
785 stride: u64::from(index_stride),
786 });
787 };
788 let Some(indexes_size) =
789 mesh.size.index_count.unwrap().checked_mul(index_stride)
790 else {
791 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
792 buffer_ident: index_buffer.error_ident(),
793 offset: u64::from(mesh.size.index_count.unwrap())
794 .saturating_mul(u64::from(index_stride)),
795 count: u64::from(mesh.size.index_count.unwrap()),
796 stride: u64::from(index_stride),
797 });
798 };
799
800 if mesh.size.index_count.unwrap() % 3 != 0 {
801 return Err(BuildAccelerationStructureError::InvalidIndexCount(
802 index_buffer.error_ident(),
803 mesh.size.index_count.unwrap(),
804 ));
805 }
806 if index_buffer.size < u64::from(vertex_offset)
807 || index_buffer.size - u64::from(vertex_offset)
808 < u64::from(indexes_size)
809 {
810 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
811 buffer_ident: index_buffer.error_ident(),
812 offset: u64::from(vertex_offset),
813 region_size: u64::from(indexes_size),
814 buffer_size: index_buffer.size,
815 });
816 }
817
818 state.buffer_memory_init_actions.extend(
819 index_buffer.initialization_status.read().create_action(
820 index_buffer,
821 u64::from(vertex_offset)
822 ..(u64::from(vertex_offset) + u64::from(indexes_size)),
823 MemoryInitKind::NeedsInitializedMemory,
824 ),
825 );
826 Some((index_raw, vertex_offset))
827 } else {
828 None
829 };
830 let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
831 {
832 if !blas
833 .flags
834 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
835 {
836 return Err(BuildAccelerationStructureError::UseTransformMissing(
837 blas.error_ident(),
838 ));
839 }
840 if mesh.transform_buffer_offset.is_none() {
841 return Err(BuildAccelerationStructureError::MissingAssociatedData(
842 transform_buffer.error_ident(),
843 ));
844 }
845 let transform_pending = state.tracker.buffers.set_single(
846 transform_buffer,
847 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
848 );
849 if mesh.transform_buffer_offset.is_none() {
850 return Err(BuildAccelerationStructureError::MissingAssociatedData(
851 transform_buffer.error_ident(),
852 ));
853 }
854 let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
855 transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
856
857 if let Some(barrier) = transform_pending.map(|pending| {
858 pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
859 }) {
860 input_barriers.push(barrier);
861 }
862
863 let Ok(offset) = u32::try_from(mesh.transform_buffer_offset.unwrap())
865 else {
866 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
867 buffer_ident: transform_buffer.error_ident(),
868 offset: mesh.transform_buffer_offset.unwrap(),
869 count: mesh.transform_buffer_offset.unwrap(),
870 stride: 1,
871 });
872 };
873
874 if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT as u32 != 0 {
875 return Err(
876 BuildAccelerationStructureError::UnalignedTransformBufferOffset(
877 transform_buffer.error_ident(),
878 ),
879 );
880 }
881 if transform_buffer.size < 48
882 || transform_buffer.size - 48 < u64::from(offset)
883 {
884 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
885 buffer_ident: transform_buffer.error_ident(),
886 region_size: 48,
887 offset: u64::from(offset),
888 buffer_size: transform_buffer.size,
889 });
890 }
891 state.buffer_memory_init_actions.extend(
892 transform_buffer.initialization_status.read().create_action(
893 transform_buffer,
894 u64::from(offset)..(u64::from(offset) + 48),
895 MemoryInitKind::NeedsInitializedMemory,
896 ),
897 );
898 Some((transform_raw, offset))
899 } else {
900 if blas
901 .flags
902 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
903 {
904 return Err(BuildAccelerationStructureError::TransformMissing(
905 blas.error_ident(),
906 ));
907 }
908 None
909 };
910
911 let triangles = hal::AccelerationStructureTriangles {
912 vertex_buffer: Some(vertex_buffer),
913 vertex_format: mesh.size.vertex_format,
914 first_vertex: mesh.first_vertex,
915 vertex_count: mesh.size.vertex_count,
916 vertex_stride: mesh.vertex_stride,
917 indices: index_buffer.map(|(index_raw, vertex_offset)| {
918 hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
919 format: mesh.size.index_format.unwrap(),
920 buffer: Some(index_raw),
921 offset: vertex_offset,
922 count: mesh.size.index_count.unwrap(),
923 }
924 }),
925 transform: transform_buffer.map(|(transform_raw, offset)| {
926 hal::AccelerationStructureTriangleTransform {
927 buffer: transform_raw,
928 offset,
929 }
930 }),
931 flags: mesh.size.flags,
932 };
933 triangle_entries.push(triangles);
934 }
935
936 {
937 let scratch_buffer_offset = *scratch_buffer_blas_size;
938 *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
939 blas.size_info.build_scratch_size,
940 u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
941 ));
942
943 blas_storage.push(BlasStore {
944 blas: blas.clone(),
945 entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
946 scratch_buffer_offset,
947 });
948 }
949 }
950 ArcBlasGeometries::AabbGeometries(aabb_geometries) => {
951 let mut aabb_entries =
952 Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::new();
953
954 for (i, aabb) in aabb_geometries.iter().enumerate() {
955 let size_desc = match &blas.sizes {
956 wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => descriptors,
957 _ => {
958 return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
959 blas.error_ident(),
960 ));
961 }
962 };
963 if i >= size_desc.len() {
964 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
965 blas.error_ident(),
966 ));
967 }
968 let size_desc = &size_desc[i];
969
970 if size_desc.flags != aabb.size.flags {
971 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
972 blas.error_ident(),
973 size_desc.flags,
974 aabb.size.flags,
975 ));
976 }
977
978 if size_desc.primitive_count < aabb.size.primitive_count {
979 return Err(
980 BuildAccelerationStructureError::IncompatibleBlasAabbPrimitiveCount(
981 blas.error_ident(),
982 size_desc.primitive_count,
983 aabb.size.primitive_count,
984 ),
985 );
986 }
987
988 if aabb.primitive_offset % 8 != 0 {
989 return Err(
990 BuildAccelerationStructureError::UnalignedAabbPrimitiveOffset(
991 blas.error_ident(),
992 ),
993 );
994 }
995
996 if aabb.stride < wgt::AABB_GEOMETRY_MIN_STRIDE || aabb.stride % 8 != 0 {
997 return Err(BuildAccelerationStructureError::InvalidAabbStride(
998 blas.error_ident(),
999 aabb.stride,
1000 ));
1001 }
1002
1003 let aabb_buffer = aabb.aabb_buffer.clone();
1004 let aabb_pending = state.tracker.buffers.set_single(
1005 &aabb_buffer,
1006 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
1007 );
1008 let aabb_raw = {
1009 let aabb_raw = aabb.aabb_buffer.as_ref().try_raw(state.snatch_guard)?;
1010 let aabb_buffer = &aabb.aabb_buffer;
1011 aabb_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1012
1013 if let Some(barrier) = aabb_pending.map(|pending| {
1014 pending.into_hal(aabb_buffer.as_ref(), state.snatch_guard)
1015 }) {
1016 input_barriers.push(barrier);
1017 }
1018
1019 let Some(aabb_size) = u32::try_from(aabb.stride)
1021 .ok()
1022 .and_then(|stride| aabb.size.primitive_count.checked_mul(stride))
1023 else {
1024 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
1025 buffer_ident: aabb_buffer.error_ident(),
1026 offset: u64::from(aabb.size.primitive_count)
1027 .saturating_mul(aabb.stride),
1028 count: u64::from(aabb.size.primitive_count),
1029 stride: aabb.stride,
1030 });
1031 };
1032
1033 if aabb_buffer.size < u64::from(aabb.primitive_offset)
1034 || aabb_buffer.size - u64::from(aabb.primitive_offset)
1035 < u64::from(aabb_size)
1036 {
1037 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
1038 buffer_ident: aabb_buffer.error_ident(),
1039 region_size: u64::from(aabb_size),
1040 offset: u64::from(aabb.primitive_offset),
1041 buffer_size: aabb_buffer.size,
1042 });
1043 }
1044
1045 state.buffer_memory_init_actions.extend(
1046 aabb_buffer.initialization_status.read().create_action(
1047 aabb_buffer,
1048 u64::from(aabb.primitive_offset)
1049 ..u64::from(aabb.primitive_offset) + u64::from(aabb_size),
1050 MemoryInitKind::NeedsInitializedMemory,
1051 ),
1052 );
1053 aabb_raw
1054 };
1055
1056 aabb_entries.push(hal::AccelerationStructureAABBs {
1057 buffer: Some(aabb_raw),
1058 offset: aabb.primitive_offset,
1059 count: aabb.size.primitive_count,
1060 stride: aabb.stride,
1061 flags: aabb.size.flags,
1062 });
1063 }
1064
1065 {
1066 let scratch_buffer_offset = *scratch_buffer_blas_size;
1067 *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
1068 blas.size_info.build_scratch_size,
1069 u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
1070 ));
1071
1072 blas_storage.push(BlasStore {
1073 blas: blas.clone(),
1074 entries: hal::AccelerationStructureEntries::AABBs(aabb_entries),
1075 scratch_buffer_offset,
1076 });
1077 }
1078 }
1079 }
1080 }
1081 Ok(())
1082}
1083
1084fn map_blas<'a>(
1085 storage: &'a BlasStore<'_>,
1086 scratch_buffer: &'a dyn hal::DynBuffer,
1087 snatch_guard: &'a SnatchGuard,
1088 blases_compactable: &mut Vec<(
1089 &'a dyn hal::DynBuffer,
1090 &'a dyn hal::DynAccelerationStructure,
1091 )>,
1092) -> Result<
1093 hal::BuildAccelerationStructureDescriptor<
1094 'a,
1095 dyn hal::DynBuffer,
1096 dyn hal::DynAccelerationStructure,
1097 >,
1098 BuildAccelerationStructureError,
1099> {
1100 let BlasStore {
1101 blas,
1102 entries,
1103 scratch_buffer_offset,
1104 } = storage;
1105 if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1106 log::debug!("only rebuild implemented")
1107 }
1108 let raw = blas.try_raw(snatch_guard)?;
1109
1110 let state_lock = blas.compacted_state.lock();
1111 if let BlasCompactState::Compacted = *state_lock {
1112 return Err(BuildAccelerationStructureError::CompactedBlas(
1113 blas.error_ident(),
1114 ));
1115 }
1116
1117 if blas
1118 .flags
1119 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1120 {
1121 blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1122 }
1123 Ok(hal::BuildAccelerationStructureDescriptor {
1124 entries,
1125 mode: hal::AccelerationStructureBuildMode::Build,
1126 flags: blas.flags,
1127 source_acceleration_structure: None,
1128 destination_acceleration_structure: raw,
1129 scratch_buffer,
1130 scratch_buffer_offset: *scratch_buffer_offset,
1131 })
1132}
1133
/// Records all barriers and build commands for the queued BLAS builds into
/// `cmd_buf_raw`.
///
/// Statement order here mirrors the required GPU command ordering
/// (input barriers → AS barrier → builds → scratch barrier → compacted-size
/// queries → final AS barrier), so the steps below must not be rearranged.
///
/// * `blas_present` / `tlas_present` — whether this batch contains any BLAS /
///   TLAS builds; together they gate which barriers get emitted.
/// * `input_barriers` — caller-collected transitions for every geometry input
///   buffer, applied before any build is recorded.
/// * `blas_descriptors` — one build descriptor per BLAS (produced by
///   `map_blas`).
/// * `scratch_buffer_barrier` — barrier on the shared scratch buffer, emitted
///   only when TLAS builds follow the BLAS builds — presumably because the
///   TLAS stage reuses the same scratch allocation; confirm against caller.
/// * `blas_s_for_compaction` — `(readback buffer, BLAS)` pairs whose compacted
///   size is queried right after the builds.
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    // Apply the caller-collected input-buffer barriers before anything else.
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            // Move acceleration-structure memory from "build input" to
            // "build output" so the builds below may write it.
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            // Record every queued BLAS build in a single hal call.
            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    // Only needed when TLAS builds come after the BLAS builds in this batch.
    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    // Accumulate the usage flags for the single combined barrier at the end;
    // which bits are set depends on what this batch actually built.
    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            // Self-transition (from == to) on the readback buffer — acts as a
            // flush/availability barrier before the size query writes into it.
            // NOTE(review): confirm from == to is intentional here.
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        // Record the compacted-size query for this BLAS into its buffer.
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        // The built BLAS will later be read as a copy source — presumably by
        // the subsequent compaction copy. NOTE(review): only destination_usage
        // gains COPY_SRC; verify no matching source_usage bit is required.
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        // Freshly built BLASes transition from build output to being inputs
        // (e.g. for a following TLAS build).
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        // NOTE(review): from = SHADER_INPUT, to = BUILD_OUTPUT for the TLAS
        // case looks inverted relative to the BLAS case — confirm this matches
        // the hal barrier semantics.
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    // One combined barrier covering everything recorded above.
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}