use alloc::{sync::Arc, vec::Vec};
use core::{
    cmp::max,
    num::NonZeroU64,
    ops::{Deref, Range},
};

use wgt::{math::align_to, BufferUsages, BufferUses, Features};

use crate::{
    command::encoder::EncodingState,
    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
    resource::InvalidResourceError,
};
use crate::{command::EncoderStateError, device::resource::CommandIndices};
use crate::{
    command::{ArcCommand, ArcReferences, CommandBufferMutable},
    device::queue::TempResource,
    global::Global,
    id::CommandEncoderId,
    init_tracker::MemoryInitKind,
    ray_tracing::{
        ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
        ArcTlasPackage, BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError,
        OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
    },
    resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
    scratch::ScratchBuffer,
    snatch::SnatchGuard,
};
use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};

use crate::id::{BlasId, TlasId};

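/// A BLAS queued for building, together with its hal geometry entries and its
/// offset into the shared scratch buffer.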
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

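/// A TLAS queued for building, together with its hal instance entries and its
/// offset into the shared scratch buffer.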
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

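/// An [`UnsafeTlasStore`] plus the byte range of this TLAS's instances within
/// the shared instance staging data.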
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}

impl Global {
    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
        self.hub.blas_s.get(blas_id).get()
    }

    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
        self.hub.tlas_s.get(tlas_id).get()
    }

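    /// Marks the given BLASes and TLASes as built, without recording an actual
    /// build, so that later submission-time validation accepts them (used with
    /// the raw encoding API). Requires `Features::EXPERIMENTAL_RAY_QUERY`.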
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.with_buffer(
            crate::command::EncodingApi::Raw,
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());

                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

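    /// Resolves the id-based BLAS build entries and TLAS packages into their
    /// `Arc`-based equivalents and enqueues a single
    /// [`ArcCommand::BuildAccelerationStructures`] on the encoder.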
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
            let blas = blas_iter
                .map(|blas_entry| {
                    let geometries = match blas_entry.geometries {
                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
                            let tri_geo = triangle_geometries
                                .map(|tg| {
                                    Ok(ArcBlasTriangleGeometry {
                                        size: tg.size.clone(),
                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
                                        index_buffer: tg
                                            .index_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        transform_buffer: tg
                                            .transform_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        first_vertex: tg.first_vertex,
                                        vertex_stride: tg.vertex_stride,
                                        first_index: tg.first_index,
                                        transform_buffer_offset: tg.transform_buffer_offset,
                                    })
                                })
                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
                            ArcBlasGeometries::TriangleGeometries(tri_geo)
                        }
                    };
                    Ok(ArcBlasBuildEntry {
                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
                        geometries,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            let tlas = tlas_iter
                .map(|tlas_package| {
                    let instances = tlas_package
                        .instances
                        .map(|instance| {
                            instance
                                .as_ref()
                                .map(|instance| {
                                    Ok(ArcTlasInstance {
                                        blas: self.resolve_blas_id(instance.blas_id)?,
                                        transform: *instance.transform,
                                        custom_data: instance.custom_data,
                                        mask: instance.mask,
                                    })
                                })
                                .transpose()
                        })
                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
                    Ok(ArcTlasPackage {
                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
                        instances,
                        lowest_unmodified: tlas_package.lowest_unmodified,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
        })
    }
}

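/// Encodes a batch of BLAS and TLAS builds: validates the inputs, sizes a
/// shared scratch buffer, stages TLAS instance data, and records the hal
/// build commands plus the barriers between the BLAS and TLAS phases.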
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            state.device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.iter().flatten() {
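            // The instance custom index is limited to 24 bits by the underlying APIs.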
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

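    // All builds share a single scratch buffer, sized for the larger of the
    // BLAS and TLAS requirements; the two phases are separated by a barrier.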
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

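    // Upload the packed TLAS instance data through a staging buffer and copy
    // the relevant range into each TLAS's instance buffer before building.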
    if tlas_present {
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}

impl CommandBufferMutable {
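    /// Submission-time validation of recorded acceleration-structure actions:
    /// build actions stamp the BLASes/TLASes involved with a monotonically
    /// increasing build index, and TLAS-use actions verify that the TLAS and
    /// every BLAS it depends on have been built, and not out of order.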
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }

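    /// Gathers the raw BLASes that every TLAS built or used by this command
    /// buffer depends on and hands them to the backend encoder via
    /// `set_acceleration_structure_dependencies`.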
    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
        let tlas_dependencies_locks: Vec<_> = self
            .as_actions
            .iter()
            .filter_map(|action| {
                if let AsAction::UseTlas(tlas) = action {
                    Some(tlas.dependencies.read())
                } else {
                    None
                }
            })
            .collect();
        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
        let mut dependencies = Vec::new();
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    for tlas_build in build.tlas_s_built.iter() {
                        for dependency in &tlas_build.dependencies {
                            if let Some(dependency) = dependency.raw(snatch_guard) {
                                dependencies.push(dependency);
                            }
                        }
                    }
                }
                AsAction::UseTlas(_tlas) => {
                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap();
                    for dependency in tlas_dependencies.iter() {
                        if let Some(dependency) = dependency.raw(snatch_guard) {
                            dependencies.push(dependency);
                        }
                    }
                }
            }
        }
        if !dependencies.is_empty() {
            unsafe {
                self.encoder
                    .raw
                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
            }
        }
    }
}

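/// Walks the BLAS build entries, validating each triangle geometry against
/// the sizes the BLAS was created with, recording tracker transitions and
/// memory-init actions for the vertex/index/transform buffers, and
/// accumulating the hal entries plus per-BLAS scratch offsets.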
fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
    blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
    build_command: &mut AsBuild,
    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'buffers>>,
    state: &mut EncodingState<'snatch_guard, '_>,
) -> Result<(), BuildAccelerationStructureError> {
    for entry in blas_iter {
        let blas = &entry.blas;
        state.tracker.blas_s.insert_single(blas.clone());

        build_command.blas_s_built.push(blas.clone());

        match &entry.geometries {
            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
                let mut triangle_entries =
                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();

                for (i, mesh) in triangle_geometries.iter().enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != mesh.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            mesh.size.flags,
                        ));
                    }

                    if size_desc.vertex_count < mesh.size.vertex_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
                                blas.error_ident(),
                                size_desc.vertex_count,
                                mesh.size.vertex_count,
                            ),
                        );
                    }

                    if size_desc.vertex_format != mesh.size.vertex_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
                            blas.error_ident(),
                            size_desc.vertex_format,
                            mesh.size.vertex_format,
                        ));
                    }

                    if size_desc
                        .vertex_format
                        .min_acceleration_structure_vertex_stride()
                        > mesh.vertex_stride
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .min_acceleration_structure_vertex_stride(),
                            mesh.vertex_stride,
                        ));
                    }

                    if mesh.vertex_stride
                        % size_desc
                            .vertex_format
                            .acceleration_structure_stride_alignment()
                        != 0
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .acceleration_structure_stride_alignment(),
                            mesh.vertex_stride,
                        ));
                    }

                    match (size_desc.index_count, mesh.size.index_count) {
                        (Some(_), None) | (None, Some(_)) => {
                            return Err(
                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
                                    blas.error_ident(),
                                ),
                            )
                        }
                        (Some(create), Some(build)) if create < build => {
                            return Err(
                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
                                    blas.error_ident(),
                                    create,
                                    build,
                                ),
                            )
                        }
                        _ => {}
                    }

                    if size_desc.index_format != mesh.size.index_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
                            blas.error_ident(),
                            size_desc.index_format,
                            mesh.size.index_format,
                        ));
                    }

                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
                            blas.error_ident(),
                        ));
                    }
                    let vertex_buffer = mesh.vertex_buffer.clone();
                    let vertex_pending = state.tracker.buffers.set_single(
                        &vertex_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );
                    let vertex_buffer = {
                        let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
                        let vertex_buffer = &mesh.vertex_buffer;
                        vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

                        if let Some(barrier) = vertex_pending.map(|pending| {
                            pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }
                        if vertex_buffer.size
                            < (mesh.size.vertex_count + mesh.first_vertex) as u64
                                * mesh.vertex_stride
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                                vertex_buffer.error_ident(),
                                vertex_buffer.size,
                                (mesh.size.vertex_count + mesh.first_vertex) as u64
                                    * mesh.vertex_stride,
                            ));
                        }
                        let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
                        state.buffer_memory_init_actions.extend(
                            vertex_buffer.initialization_status.read().create_action(
                                vertex_buffer,
                                vertex_buffer_offset
                                    ..(vertex_buffer_offset
                                        + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        vertex_raw
                    };
                    let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
                        if mesh.first_index.is_none() || mesh.size.index_count.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                index_buffer.error_ident(),
                            ));
                        }
                        let index_pending = state.tracker.buffers.set_single(
                            index_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        let index_raw = index_buffer.try_raw(state.snatch_guard)?;
                        index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

                        if let Some(barrier) = index_pending.map(|pending| {
                            pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }
                        let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
                        let offset = mesh.first_index.unwrap() as u64 * index_stride;
                        let index_buffer_size =
                            mesh.size.index_count.unwrap() as u64 * index_stride;

                        if mesh.size.index_count.unwrap() % 3 != 0 {
                            return Err(BuildAccelerationStructureError::InvalidIndexCount(
                                index_buffer.error_ident(),
                                mesh.size.index_count.unwrap(),
                            ));
                        }
                        if index_buffer.size
                            < mesh.size.index_count.unwrap() as u64 * index_stride + offset
                        {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                                index_buffer.error_ident(),
                                index_buffer.size,
                                mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                            ));
                        }

                        state.buffer_memory_init_actions.extend(
                            index_buffer.initialization_status.read().create_action(
                                index_buffer,
                                offset..(offset + index_buffer_size),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        Some(index_raw)
                    } else {
                        None
                    };
                    let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
                    {
                        if !blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::UseTransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        if mesh.transform_buffer_offset.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                transform_buffer.error_ident(),
                            ));
                        }
                        let transform_pending = state.tracker.buffers.set_single(
                            transform_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
                        transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

                        if let Some(barrier) = transform_pending.map(|pending| {
                            pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
                        }) {
                            input_barriers.push(barrier);
                        }

                        let offset = mesh.transform_buffer_offset.unwrap();

                        if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                            return Err(
                                BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                                    transform_buffer.error_ident(),
                                ),
                            );
                        }
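                        // A geometry transform is a 3x4 row-major f32 matrix, i.e. 48 bytes.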
                        if transform_buffer.size < 48 + offset {
                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                                transform_buffer.error_ident(),
                                transform_buffer.size,
                                48 + offset,
                            ));
                        }
                        state.buffer_memory_init_actions.extend(
                            transform_buffer.initialization_status.read().create_action(
                                transform_buffer,
                                offset..(offset + 48),
                                MemoryInitKind::NeedsInitializedMemory,
                            ),
                        );
                        Some(transform_raw)
                    } else {
                        if blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::TransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        None
                    };

                    let triangles = hal::AccelerationStructureTriangles {
                        vertex_buffer: Some(vertex_buffer),
                        vertex_format: mesh.size.vertex_format,
                        first_vertex: mesh.first_vertex,
                        vertex_count: mesh.size.vertex_count,
                        vertex_stride: mesh.vertex_stride,
                        indices: index_buffer.map(|index_buffer| {
                            let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                            hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                                format: mesh.size.index_format.unwrap(),
                                buffer: Some(index_buffer),
                                offset: mesh.first_index.unwrap() * index_stride,
                                count: mesh.size.index_count.unwrap(),
                            }
                        }),
                        transform: transform_buffer.map(|transform_buffer| {
                            hal::AccelerationStructureTriangleTransform {
                                buffer: transform_buffer,
                                offset: mesh.transform_buffer_offset.unwrap() as u32,
                            }
                        }),
                        flags: mesh.size.flags,
                    };
                    triangle_entries.push(triangles);
                }

                {
                    let scratch_buffer_offset = *scratch_buffer_blas_size;
                    *scratch_buffer_blas_size += align_to(
                        blas.size_info.build_scratch_size as u32,
                        state.device.alignments.ray_tracing_scratch_buffer_alignment,
                    ) as u64;

                    blas_storage.push(BlasStore {
                        blas: blas.clone(),
                        entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                        scratch_buffer_offset,
                    });
                }
            }
        }
    }
    Ok(())
}

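/// Turns a [`BlasStore`] into the hal build descriptor for that BLAS,
/// rejecting already-compacted BLASes and queueing a compaction-size
/// readback for BLASes built with `ALLOW_COMPACTION`.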
fn map_blas<'a>(
    storage: &'a BlasStore<'_>,
    scratch_buffer: &'a dyn hal::DynBuffer,
    snatch_guard: &'a SnatchGuard,
    blases_compactable: &mut Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) -> Result<
    hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >,
    BuildAccelerationStructureError,
> {
    let BlasStore {
        blas,
        entries,
        scratch_buffer_offset,
    } = storage;
    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
        log::debug!("BLAS build requested with PreferUpdate, but only rebuild is implemented")
    }
    let raw = blas.try_raw(snatch_guard)?;

    let state_lock = blas.compacted_state.lock();
    if let BlasCompactState::Compacted = *state_lock {
        return Err(BuildAccelerationStructureError::CompactedBlas(
            blas.error_ident(),
        ));
    }

    if blas
        .flags
        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
    {
        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
    }
    Ok(hal::BuildAccelerationStructureDescriptor {
        entries,
        mode: hal::AccelerationStructureBuildMode::Build,
        flags: blas.flags,
        source_acceleration_structure: None,
        destination_acceleration_structure: raw,
        scratch_buffer,
        scratch_buffer_offset: *scratch_buffer_offset,
    })
}

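/// Records the BLAS build phase: input-buffer barriers, the BLAS builds
/// themselves, compaction-size readback for compactable BLASes, and the
/// acceleration-structure barrier separating the BLAS and TLAS phases.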
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT;
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}