1use alloc::{sync::Arc, vec::Vec};
2use core::{
3 cmp::max,
4 num::NonZeroU64,
5 ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11 command::encoder::EncodingState,
12 ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
13 resource::InvalidResourceError,
14};
15use crate::{command::EncoderStateError, device::resource::CommandIndices};
16use crate::{
17 command::{ArcCommand, ArcReferences, CommandBufferMutable},
18 device::queue::TempResource,
19 global::Global,
20 id::CommandEncoderId,
21 init_tracker::MemoryInitKind,
22 ray_tracing::{
23 ArcBlasAabbGeometry, ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry,
24 ArcTlasInstance, ArcTlasPackage, BlasBuildEntry, BlasGeometries,
25 BuildAccelerationStructureError, OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
26 },
27 resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
28 scratch::ScratchBuffer,
29 snatch::SnatchGuard,
30};
31use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
32
33use crate::id::{BlasId, TlasId};
34
/// A validated bottom-level acceleration structure build, produced by `iter_blas` and
/// later turned into a hal build descriptor by `map_blas`.
struct BlasStore<'a> {
    /// The BLAS being built.
    blas: Arc<Blas>,
    /// The hal geometry entries (triangles or AABBs) the BLAS is built from.
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// Byte offset of this build's region inside the shared scratch buffer.
    scratch_buffer_offset: u64,
}
40
/// The hal-facing part of a top-level acceleration structure build.
struct UnsafeTlasStore<'a> {
    /// The TLAS being built.
    tlas: Arc<Tlas>,
    /// The hal instance entries (pointing at the TLAS's instance buffer) to build from.
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// Byte offset of this build's region inside the shared scratch buffer.
    scratch_buffer_offset: u64,
}
46
/// A validated TLAS build together with the location of its serialized instance data.
struct TlasStore<'a> {
    /// The TLAS, its hal entries and its scratch offset.
    internal: UnsafeTlasStore<'a>,
    /// Byte range of this TLAS's serialized instances inside the combined
    /// instance staging data built by `build_acceleration_structures`.
    range: Range<usize>,
}
51
52impl Global {
53 fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
54 self.hub.blas_s.get(blas_id).get()
55 }
56
57 fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
58 self.hub.tlas_s.get(tlas_id).get()
59 }
60
61 pub fn command_encoder_mark_acceleration_structures_built(
62 &self,
63 command_encoder_id: CommandEncoderId,
64 blas_ids: &[BlasId],
65 tlas_ids: &[TlasId],
66 ) -> Result<(), EncoderStateError> {
67 profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
68
69 let hub = &self.hub;
70
71 let cmd_enc = hub.command_encoders.get(command_encoder_id);
72
73 let mut cmd_buf_data = cmd_enc.data.lock();
74 cmd_buf_data.with_buffer(
75 crate::command::EncodingApi::Raw,
76 |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
77 let device = &cmd_enc.device;
78 device.check_is_valid()?;
79 device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
80
81 let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());
82
83 for blas in blas_ids {
84 let blas = hub.blas_s.get(*blas).get()?;
85 build_command.blas_s_built.push(blas);
86 }
87
88 for tlas in tlas_ids {
89 let tlas = hub.tlas_s.get(*tlas).get()?;
90 build_command.tlas_s_built.push(TlasBuild {
91 tlas,
92 dependencies: Vec::new(),
93 });
94 }
95
96 cmd_buf_data.as_actions.push(AsAction::Build(build_command));
97 Ok(())
98 },
99 )
100 }
101
102 pub fn command_encoder_build_acceleration_structures<'a>(
103 &self,
104 command_encoder_id: CommandEncoderId,
105 blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
106 tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
107 ) -> Result<(), EncoderStateError> {
108 profiling::scope!("CommandEncoder::build_acceleration_structures");
109
110 let hub = &self.hub;
111
112 let cmd_enc = hub.command_encoders.get(command_encoder_id);
113 let mut cmd_buf_data = cmd_enc.data.lock();
114
115 cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
116 let blas = blas_iter
117 .map(|blas_entry| {
118 let geometries = match blas_entry.geometries {
119 BlasGeometries::TriangleGeometries(triangle_geometries) => {
120 let tri_geo = triangle_geometries
121 .map(|tg| {
122 Ok(ArcBlasTriangleGeometry {
123 size: tg.size.clone(),
124 vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
125 index_buffer: tg
126 .index_buffer
127 .map(|id| self.resolve_buffer_id(id))
128 .transpose()?,
129 transform_buffer: tg
130 .transform_buffer
131 .map(|id| self.resolve_buffer_id(id))
132 .transpose()?,
133 first_vertex: tg.first_vertex,
134 vertex_stride: tg.vertex_stride,
135 first_index: tg.first_index,
136 transform_buffer_offset: tg.transform_buffer_offset,
137 })
138 })
139 .collect::<Result<_, BuildAccelerationStructureError>>()?;
140 ArcBlasGeometries::TriangleGeometries(tri_geo)
141 }
142 BlasGeometries::AabbGeometries(aabb_geometries) => {
143 let aabb_geo = aabb_geometries
144 .map(|ag| {
145 Ok(ArcBlasAabbGeometry {
146 size: ag.size.clone(),
147 stride: ag.stride,
148 aabb_buffer: self.resolve_buffer_id(ag.aabb_buffer)?,
149 primitive_offset: ag.primitive_offset,
150 })
151 })
152 .collect::<Result<_, BuildAccelerationStructureError>>()?;
153 ArcBlasGeometries::AabbGeometries(aabb_geo)
154 }
155 };
156 Ok(ArcBlasBuildEntry {
157 blas: self.resolve_blas_id(blas_entry.blas_id)?,
158 geometries,
159 })
160 })
161 .collect::<Result<_, BuildAccelerationStructureError>>()?;
162
163 let tlas = tlas_iter
164 .map(|tlas_package| {
165 let instances = tlas_package
166 .instances
167 .map(|instance| {
168 instance
169 .as_ref()
170 .map(|instance| {
171 Ok(ArcTlasInstance {
172 blas: self.resolve_blas_id(instance.blas_id)?,
173 transform: *instance.transform,
174 custom_data: instance.custom_data,
175 mask: instance.mask,
176 })
177 })
178 .transpose()
179 })
180 .collect::<Result<_, BuildAccelerationStructureError>>()?;
181 Ok(ArcTlasPackage {
182 tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
183 instances,
184 lowest_unmodified: tlas_package.lowest_unmodified,
185 })
186 })
187 .collect::<Result<_, BuildAccelerationStructureError>>()?;
188
189 Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
190 })
191 }
192}
193
/// Validates and records one batch of acceleration structure builds into `state`.
///
/// In order, this:
/// 1. validates every BLAS entry via `iter_blas`, which collects input-buffer barriers,
///    per-BLAS scratch offsets and the total BLAS scratch size;
/// 2. validates every TLAS package, serializes all instances back-to-back into one
///    CPU-side byte vector, and computes per-TLAS scratch offsets;
/// 3. allocates a single scratch buffer sized to the larger of the BLAS and TLAS scratch
///    totals (both BLAS and TLAS descriptors point at this same buffer);
/// 4. passes the BLAS descriptors and barriers to `build_blas`, then, if any TLAS is present,
///    uploads the instance data through a staging buffer into each TLAS's instance buffer
///    and records the TLAS builds followed by a `BUILD_OUTPUT -> SHADER_INPUT` barrier;
/// 5. hands the scratch (and staging) buffers to `state.temp_resources` and queues an
///    `AsAction::Build` for submission-time validation.
///
/// # Errors
/// Returns an error if the device lacks `EXPERIMENTAL_RAY_QUERY`, if any BLAS/TLAS entry fails
/// validation, if a resource's raw handle is unavailable, or if buffer allocation fails
/// (including a scratch size that saturated to `u64::MAX`).
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size: u64 = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    // Serialized instances of every TLAS, concatenated; each `TlasStore::range` marks the
    // bytes that belong to one TLAS.
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        // Reserve an aligned region of the shared scratch buffer. The sum saturates, so an
        // overflow shows up as `u64::MAX` and is rejected as OutOfMemory below.
        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size = scratch_buffer_tlas_size.saturating_add(align_to(
            tlas.size_info.build_scratch_size,
            u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
        ));

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        // Empty (`None`) instance slots are skipped and do not count towards the total.
        for instance in package.instances.iter().flatten() {
            // Custom data must fit in 24 bits.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            // The backend defines the byte layout of an instance.
            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                    pipeline_intersection_data_offset: 0,
                },
            ));

            // A TLAS that allows vertex return may only reference BLASes that allow it too.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // BLAS and TLAS builds share one scratch buffer, so it only needs to be as large as
    // the bigger of the two totals.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        // Zero total scratch size: nothing is recorded. Note that `build_command` is not
        // pushed to `as_actions` on this path.
        return Ok(());
    };

    // `u64::MAX` is where the saturating scratch-size sums end up on overflow.
    if scratch_size.get() == u64::MAX {
        return Err(crate::device::DeviceError::OutOfMemory.into());
    }

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Scratch-to-scratch barrier, handed to `build_blas` (presumably to order the reuse of the
    // scratch memory between the BLAS and TLAS builds).
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    // (compaction buffer, raw BLAS) pairs for every BLAS built with ALLOW_COMPACTION.
    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        // Upload all serialized instances with one staging buffer (skipped when every TLAS
        // has zero instances).
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        // Barriers that move every written instance buffer back to TLAS-input use; they are
        // issued together after all copies.
        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // TLASes without instances have nothing to upload.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                // A non-empty range implies the staging source was non-empty, so the
                // staging buffer exists here.
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
468
469impl CommandBufferMutable {
470 pub(crate) fn validate_acceleration_structure_actions(
471 &self,
472 snatch_guard: &SnatchGuard,
473 command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
474 ) -> Result<(), ValidateAsActionsError> {
475 profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
476 for action in &self.as_actions {
477 match action {
478 AsAction::Build(build) => {
479 let build_command_index = NonZeroU64::new(
480 command_index_guard.next_acceleration_structure_build_command_index,
481 )
482 .unwrap();
483
484 command_index_guard.next_acceleration_structure_build_command_index += 1;
485 for blas in build.blas_s_built.iter() {
486 let mut state_lock = blas.compacted_state.lock();
487 *state_lock = match *state_lock {
488 BlasCompactState::Compacted => {
489 unreachable!("Should be validated out in build.")
490 }
491 _ => BlasCompactState::Idle,
494 };
495 *blas.built_index.write() = Some(build_command_index);
496 }
497
498 for tlas_build in build.tlas_s_built.iter() {
499 for blas in &tlas_build.dependencies {
500 if blas.built_index.read().is_none() {
501 return Err(ValidateAsActionsError::UsedUnbuiltBlas(
502 blas.error_ident(),
503 tlas_build.tlas.error_ident(),
504 ));
505 }
506 }
507 *tlas_build.tlas.built_index.write() = Some(build_command_index);
508 tlas_build
509 .tlas
510 .dependencies
511 .write()
512 .clone_from(&tlas_build.dependencies)
513 }
514 }
515 AsAction::UseTlas(tlas) => {
516 let tlas_build_index = tlas.built_index.read();
517 let dependencies = tlas.dependencies.read();
518
519 if (*tlas_build_index).is_none() {
520 return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
521 }
522 for blas in dependencies.deref() {
523 let blas_build_index = *blas.built_index.read();
524 if blas_build_index.is_none() {
525 return Err(ValidateAsActionsError::UsedUnbuiltBlas(
526 tlas.error_ident(),
527 blas.error_ident(),
528 ));
529 }
530 if blas_build_index.unwrap() > tlas_build_index.unwrap() {
531 return Err(ValidateAsActionsError::BlasNewerThenTlas(
532 blas.error_ident(),
533 tlas.error_ident(),
534 ));
535 }
536 blas.try_raw(snatch_guard)?;
537 }
538 }
539 }
540 }
541 Ok(())
542 }
543
544 pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
545 profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
546 let tlas_dependencies_locks: Vec<_> = self
547 .as_actions
548 .iter()
549 .filter_map(|action| {
550 if let AsAction::UseTlas(tlas) = action {
551 Some(tlas.dependencies.read())
552 } else {
553 None
554 }
555 })
556 .collect();
557 let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
558 let mut dependencies = Vec::new();
559 for action in &self.as_actions {
560 match action {
561 AsAction::Build(build) => {
562 for tlas_build in build.tlas_s_built.iter() {
563 for dependency in &tlas_build.dependencies {
564 if let Some(dependency) = dependency.raw(snatch_guard) {
565 dependencies.push(dependency);
566 }
567 }
568 }
569 }
570 AsAction::UseTlas(_tlas) => {
571 let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap(); for dependency in tlas_dependencies.iter() {
573 if let Some(dependency) = dependency.raw(snatch_guard) {
574 dependencies.push(dependency);
575 }
576 }
577 }
578 }
579 }
580 if !dependencies.is_empty() {
581 unsafe {
582 self.encoder
583 .raw
584 .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
585 }
586 }
587 }
588}
589
590fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
592 blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
593 build_command: &mut AsBuild,
594 input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
595 scratch_buffer_blas_size: &mut u64,
596 blas_storage: &mut Vec<BlasStore<'buffers>>,
597 state: &mut EncodingState<'snatch_guard, '_>,
598) -> Result<(), BuildAccelerationStructureError> {
599 for entry in blas_iter {
600 let blas = &entry.blas;
601 state.tracker.blas_s.insert_single(blas.clone());
602
603 build_command.blas_s_built.push(blas.clone());
604
605 match &entry.geometries {
606 ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
607 let mut triangle_entries =
608 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
609
610 for (i, mesh) in triangle_geometries.iter().enumerate() {
611 let size_desc = match &blas.sizes {
612 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
613 _ => {
614 return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
615 blas.error_ident(),
616 ));
617 }
618 };
619 if i >= size_desc.len() {
620 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
621 blas.error_ident(),
622 ));
623 }
624 let size_desc = &size_desc[i];
625
626 if size_desc.flags != mesh.size.flags {
627 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
628 blas.error_ident(),
629 size_desc.flags,
630 mesh.size.flags,
631 ));
632 }
633
634 if size_desc.vertex_count < mesh.size.vertex_count {
635 return Err(
636 BuildAccelerationStructureError::IncompatibleBlasVertexCount(
637 blas.error_ident(),
638 size_desc.vertex_count,
639 mesh.size.vertex_count,
640 ),
641 );
642 }
643
644 if size_desc.vertex_format != mesh.size.vertex_format {
645 return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
646 blas.error_ident(),
647 size_desc.vertex_format,
648 mesh.size.vertex_format,
649 ));
650 }
651
652 if size_desc
653 .vertex_format
654 .min_acceleration_structure_vertex_stride()
655 > mesh.vertex_stride
656 {
657 return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
658 blas.error_ident(),
659 size_desc
660 .vertex_format
661 .min_acceleration_structure_vertex_stride(),
662 mesh.vertex_stride,
663 ));
664 }
665
666 if mesh.vertex_stride
667 % size_desc
668 .vertex_format
669 .acceleration_structure_stride_alignment()
670 != 0
671 {
672 return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
673 blas.error_ident(),
674 size_desc
675 .vertex_format
676 .acceleration_structure_stride_alignment(),
677 mesh.vertex_stride,
678 ));
679 }
680
681 match (size_desc.index_count, mesh.size.index_count) {
682 (Some(_), None) | (None, Some(_)) => {
683 return Err(
684 BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
685 blas.error_ident(),
686 ),
687 )
688 }
689 (Some(create), Some(build)) if create < build => {
690 return Err(
691 BuildAccelerationStructureError::IncompatibleBlasIndexCount(
692 blas.error_ident(),
693 create,
694 build,
695 ),
696 )
697 }
698 _ => {}
699 }
700
701 if size_desc.index_format != mesh.size.index_format {
702 return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
703 blas.error_ident(),
704 size_desc.index_format,
705 mesh.size.index_format,
706 ));
707 }
708
709 if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
710 return Err(BuildAccelerationStructureError::MissingIndexBuffer(
711 blas.error_ident(),
712 ));
713 }
714 let vertex_buffer = mesh.vertex_buffer.clone();
715 let vertex_pending = state.tracker.buffers.set_single(
716 &vertex_buffer,
717 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
718 );
719 let vertex_buffer = {
720 let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
721 let vertex_buffer = &mesh.vertex_buffer;
722 vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
723
724 if let Some(barrier) = vertex_pending.map(|pending| {
725 pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
726 }) {
727 input_barriers.push(barrier);
728 }
729 if u64::from(mesh.size.vertex_count)
730 .checked_add(u64::from(mesh.first_vertex))
731 .and_then(|end_vertex| end_vertex.checked_mul(mesh.vertex_stride))
732 .is_none_or(|end| vertex_buffer.size < end)
733 {
734 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
735 buffer_ident: vertex_buffer.error_ident(),
736 offset: u64::from(mesh.first_vertex)
737 .saturating_mul(mesh.vertex_stride),
738 region_size: u64::from(mesh.size.vertex_count)
739 .saturating_mul(mesh.vertex_stride),
740 buffer_size: vertex_buffer.size,
741 });
742 }
743 let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
744 state.buffer_memory_init_actions.extend(
745 vertex_buffer.initialization_status.read().create_action(
746 vertex_buffer,
747 vertex_buffer_offset
748 ..(vertex_buffer_offset
749 + mesh.size.vertex_count as u64 * mesh.vertex_stride),
750 MemoryInitKind::NeedsInitializedMemory,
751 ),
752 );
753 vertex_raw
754 };
755 let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
756 if mesh.first_index.is_none()
757 || mesh.size.index_count.is_none()
758 || mesh.size.index_count.is_none()
759 {
760 return Err(BuildAccelerationStructureError::MissingAssociatedData(
761 index_buffer.error_ident(),
762 ));
763 }
764 let index_pending = state.tracker.buffers.set_single(
765 index_buffer,
766 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
767 );
768 let index_raw = index_buffer.try_raw(state.snatch_guard)?;
769 index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
770
771 if let Some(barrier) = index_pending.map(|pending| {
772 pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
773 }) {
774 input_barriers.push(barrier);
775 }
776 let index_stride = mesh.size.index_format.unwrap().byte_size();
777 let Some(vertex_offset) =
779 mesh.first_index.unwrap().checked_mul(index_stride)
780 else {
781 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
782 buffer_ident: index_buffer.error_ident(),
783 offset: u64::from(mesh.first_index.unwrap())
784 .saturating_mul(u64::from(index_stride)),
785 count: u64::from(mesh.first_index.unwrap()),
786 stride: u64::from(index_stride),
787 });
788 };
789 let Some(indexes_size) =
790 mesh.size.index_count.unwrap().checked_mul(index_stride)
791 else {
792 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
793 buffer_ident: index_buffer.error_ident(),
794 offset: u64::from(mesh.size.index_count.unwrap())
795 .saturating_mul(u64::from(index_stride)),
796 count: u64::from(mesh.size.index_count.unwrap()),
797 stride: u64::from(index_stride),
798 });
799 };
800
801 if mesh.size.index_count.unwrap() % 3 != 0 {
802 return Err(BuildAccelerationStructureError::InvalidIndexCount(
803 index_buffer.error_ident(),
804 mesh.size.index_count.unwrap(),
805 ));
806 }
807 if index_buffer.size < u64::from(vertex_offset)
808 || index_buffer.size - u64::from(vertex_offset)
809 < u64::from(indexes_size)
810 {
811 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
812 buffer_ident: index_buffer.error_ident(),
813 offset: u64::from(vertex_offset),
814 region_size: u64::from(indexes_size),
815 buffer_size: index_buffer.size,
816 });
817 }
818
819 state.buffer_memory_init_actions.extend(
820 index_buffer.initialization_status.read().create_action(
821 index_buffer,
822 u64::from(vertex_offset)
823 ..(u64::from(vertex_offset) + u64::from(indexes_size)),
824 MemoryInitKind::NeedsInitializedMemory,
825 ),
826 );
827 Some((index_raw, vertex_offset))
828 } else {
829 None
830 };
831 let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
832 {
833 if !blas
834 .flags
835 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
836 {
837 return Err(BuildAccelerationStructureError::UseTransformMissing(
838 blas.error_ident(),
839 ));
840 }
841 if mesh.transform_buffer_offset.is_none() {
842 return Err(BuildAccelerationStructureError::MissingAssociatedData(
843 transform_buffer.error_ident(),
844 ));
845 }
846 let transform_pending = state.tracker.buffers.set_single(
847 transform_buffer,
848 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
849 );
850 if mesh.transform_buffer_offset.is_none() {
851 return Err(BuildAccelerationStructureError::MissingAssociatedData(
852 transform_buffer.error_ident(),
853 ));
854 }
855 let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
856 transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
857
858 if let Some(barrier) = transform_pending.map(|pending| {
859 pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
860 }) {
861 input_barriers.push(barrier);
862 }
863
864 let Ok(offset) = u32::try_from(mesh.transform_buffer_offset.unwrap())
866 else {
867 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
868 buffer_ident: transform_buffer.error_ident(),
869 offset: mesh.transform_buffer_offset.unwrap(),
870 count: mesh.transform_buffer_offset.unwrap(),
871 stride: 1,
872 });
873 };
874
875 if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT as u32 != 0 {
876 return Err(
877 BuildAccelerationStructureError::UnalignedTransformBufferOffset(
878 transform_buffer.error_ident(),
879 ),
880 );
881 }
882 if transform_buffer.size < 48
883 || transform_buffer.size - 48 < u64::from(offset)
884 {
885 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
886 buffer_ident: transform_buffer.error_ident(),
887 region_size: 48,
888 offset: u64::from(offset),
889 buffer_size: transform_buffer.size,
890 });
891 }
892 state.buffer_memory_init_actions.extend(
893 transform_buffer.initialization_status.read().create_action(
894 transform_buffer,
895 u64::from(offset)..(u64::from(offset) + 48),
896 MemoryInitKind::NeedsInitializedMemory,
897 ),
898 );
899 Some((transform_raw, offset))
900 } else {
901 if blas
902 .flags
903 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
904 {
905 return Err(BuildAccelerationStructureError::TransformMissing(
906 blas.error_ident(),
907 ));
908 }
909 None
910 };
911
912 let triangles = hal::AccelerationStructureTriangles {
913 vertex_buffer: Some(vertex_buffer),
914 vertex_format: mesh.size.vertex_format,
915 first_vertex: mesh.first_vertex,
916 vertex_count: mesh.size.vertex_count,
917 vertex_stride: mesh.vertex_stride,
918 indices: index_buffer.map(|(index_raw, vertex_offset)| {
919 hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
920 format: mesh.size.index_format.unwrap(),
921 buffer: Some(index_raw),
922 offset: vertex_offset,
923 count: mesh.size.index_count.unwrap(),
924 }
925 }),
926 transform: transform_buffer.map(|(transform_raw, offset)| {
927 hal::AccelerationStructureTriangleTransform {
928 buffer: transform_raw,
929 offset,
930 }
931 }),
932 flags: mesh.size.flags,
933 };
934 triangle_entries.push(triangles);
935 }
936
937 {
938 let scratch_buffer_offset = *scratch_buffer_blas_size;
939 *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
940 blas.size_info.build_scratch_size,
941 u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
942 ));
943
944 blas_storage.push(BlasStore {
945 blas: blas.clone(),
946 entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
947 scratch_buffer_offset,
948 });
949 }
950 }
951 ArcBlasGeometries::AabbGeometries(aabb_geometries) => {
952 let mut aabb_entries =
953 Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::new();
954
955 for (i, aabb) in aabb_geometries.iter().enumerate() {
956 let size_desc = match &blas.sizes {
957 wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => descriptors,
958 _ => {
959 return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
960 blas.error_ident(),
961 ));
962 }
963 };
964 if i >= size_desc.len() {
965 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
966 blas.error_ident(),
967 ));
968 }
969 let size_desc = &size_desc[i];
970
971 if size_desc.flags != aabb.size.flags {
972 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
973 blas.error_ident(),
974 size_desc.flags,
975 aabb.size.flags,
976 ));
977 }
978
979 if size_desc.primitive_count < aabb.size.primitive_count {
980 return Err(
981 BuildAccelerationStructureError::IncompatibleBlasAabbPrimitiveCount(
982 blas.error_ident(),
983 size_desc.primitive_count,
984 aabb.size.primitive_count,
985 ),
986 );
987 }
988
989 if aabb.primitive_offset % 8 != 0 {
990 return Err(
991 BuildAccelerationStructureError::UnalignedAabbPrimitiveOffset(
992 blas.error_ident(),
993 ),
994 );
995 }
996
997 if aabb.stride < wgt::AABB_GEOMETRY_MIN_STRIDE || aabb.stride % 8 != 0 {
998 return Err(BuildAccelerationStructureError::InvalidAabbStride(
999 blas.error_ident(),
1000 aabb.stride,
1001 ));
1002 }
1003
1004 let aabb_buffer = aabb.aabb_buffer.clone();
1005 let aabb_pending = state.tracker.buffers.set_single(
1006 &aabb_buffer,
1007 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
1008 );
1009 let aabb_raw = {
1010 let aabb_raw = aabb.aabb_buffer.as_ref().try_raw(state.snatch_guard)?;
1011 let aabb_buffer = &aabb.aabb_buffer;
1012 aabb_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1013
1014 if let Some(barrier) = aabb_pending.map(|pending| {
1015 pending.into_hal(aabb_buffer.as_ref(), state.snatch_guard)
1016 }) {
1017 input_barriers.push(barrier);
1018 }
1019
1020 let Some(aabb_size) = u32::try_from(aabb.stride)
1022 .ok()
1023 .and_then(|stride| aabb.size.primitive_count.checked_mul(stride))
1024 else {
1025 return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
1026 buffer_ident: aabb_buffer.error_ident(),
1027 offset: u64::from(aabb.size.primitive_count)
1028 .saturating_mul(aabb.stride),
1029 count: u64::from(aabb.size.primitive_count),
1030 stride: aabb.stride,
1031 });
1032 };
1033
1034 if aabb_buffer.size < u64::from(aabb.primitive_offset)
1035 || aabb_buffer.size - u64::from(aabb.primitive_offset)
1036 < u64::from(aabb_size)
1037 {
1038 return Err(BuildAccelerationStructureError::InsufficientBufferSize {
1039 buffer_ident: aabb_buffer.error_ident(),
1040 region_size: u64::from(aabb_size),
1041 offset: u64::from(aabb.primitive_offset),
1042 buffer_size: aabb_buffer.size,
1043 });
1044 }
1045
1046 state.buffer_memory_init_actions.extend(
1047 aabb_buffer.initialization_status.read().create_action(
1048 aabb_buffer,
1049 u64::from(aabb.primitive_offset)
1050 ..u64::from(aabb.primitive_offset) + u64::from(aabb_size),
1051 MemoryInitKind::NeedsInitializedMemory,
1052 ),
1053 );
1054 aabb_raw
1055 };
1056
1057 aabb_entries.push(hal::AccelerationStructureAABBs {
1058 buffer: Some(aabb_raw),
1059 offset: aabb.primitive_offset,
1060 count: aabb.size.primitive_count,
1061 stride: aabb.stride,
1062 flags: aabb.size.flags,
1063 });
1064 }
1065
1066 {
1067 let scratch_buffer_offset = *scratch_buffer_blas_size;
1068 *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
1069 blas.size_info.build_scratch_size,
1070 u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
1071 ));
1072
1073 blas_storage.push(BlasStore {
1074 blas: blas.clone(),
1075 entries: hal::AccelerationStructureEntries::AABBs(aabb_entries),
1076 scratch_buffer_offset,
1077 });
1078 }
1079 }
1080 }
1081 }
1082 Ok(())
1083}
1084
1085fn map_blas<'a>(
1086 storage: &'a BlasStore<'_>,
1087 scratch_buffer: &'a dyn hal::DynBuffer,
1088 snatch_guard: &'a SnatchGuard,
1089 blases_compactable: &mut Vec<(
1090 &'a dyn hal::DynBuffer,
1091 &'a dyn hal::DynAccelerationStructure,
1092 )>,
1093) -> Result<
1094 hal::BuildAccelerationStructureDescriptor<
1095 'a,
1096 dyn hal::DynBuffer,
1097 dyn hal::DynAccelerationStructure,
1098 >,
1099 BuildAccelerationStructureError,
1100> {
1101 let BlasStore {
1102 blas,
1103 entries,
1104 scratch_buffer_offset,
1105 } = storage;
1106 if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1107 log::debug!("only rebuild implemented")
1108 }
1109 let raw = blas.try_raw(snatch_guard)?;
1110
1111 let state_lock = blas.compacted_state.lock();
1112 if let BlasCompactState::Compacted = *state_lock {
1113 return Err(BuildAccelerationStructureError::CompactedBlas(
1114 blas.error_ident(),
1115 ));
1116 }
1117
1118 if blas
1119 .flags
1120 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1121 {
1122 blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1123 }
1124 Ok(hal::BuildAccelerationStructureDescriptor {
1125 entries,
1126 mode: hal::AccelerationStructureBuildMode::Build,
1127 flags: blas.flags,
1128 source_acceleration_structure: None,
1129 destination_acceleration_structure: raw,
1130 scratch_buffer,
1131 scratch_buffer_offset: *scratch_buffer_offset,
1132 })
1133}
1134
1135fn build_blas<'a>(
1136 cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
1137 blas_present: bool,
1138 tlas_present: bool,
1139 input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1140 blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1141 'a,
1142 dyn hal::DynBuffer,
1143 dyn hal::DynAccelerationStructure,
1144 >],
1145 scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1146 blas_s_for_compaction: Vec<(
1147 &'a dyn hal::DynBuffer,
1148 &'a dyn hal::DynAccelerationStructure,
1149 )>,
1150) {
1151 unsafe {
1152 cmd_buf_raw.transition_buffers(&input_barriers);
1153 }
1154
1155 if blas_present {
1156 unsafe {
1157 cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1158 usage: hal::StateTransition {
1159 from: hal::AccelerationStructureUses::BUILD_INPUT,
1160 to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1161 },
1162 });
1163
1164 cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1165 }
1166 }
1167
1168 if blas_present && tlas_present {
1169 unsafe {
1170 cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1171 }
1172 }
1173
1174 let mut source_usage = hal::AccelerationStructureUses::empty();
1175 let mut destination_usage = hal::AccelerationStructureUses::empty();
1176 for &(buf, blas) in blas_s_for_compaction.iter() {
1177 unsafe {
1178 cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
1179 buffer: buf,
1180 usage: hal::StateTransition {
1181 from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1182 to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1183 },
1184 }])
1185 }
1186 unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
1187 destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
1188 }
1189
1190 if blas_present {
1191 source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1192 destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1193 }
1194 if tlas_present {
1195 source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1196 destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1197 }
1198 unsafe {
1199 cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1200 usage: hal::StateTransition {
1201 from: source_usage,
1202 to: destination_usage,
1203 },
1204 });
1205 }
1206}