1use alloc::{sync::Arc, vec::Vec};
2use core::{
3 cmp::max,
4 num::NonZeroU64,
5 ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11 command::encoder::EncodingState,
12 ray_tracing::{AsAction, AsBuild, BlasTriangleGeometryInfo, TlasBuild, ValidateAsActionsError},
13 resource::InvalidResourceError,
14 track::Tracker,
15};
16use crate::{command::EncoderStateError, device::resource::CommandIndices};
17use crate::{
18 command::{ArcCommand, ArcReferences, CommandBufferMutable},
19 device::queue::TempResource,
20 global::Global,
21 id::CommandEncoderId,
22 init_tracker::MemoryInitKind,
23 ray_tracing::{
24 ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
25 ArcTlasPackage, BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError,
26 OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
27 },
28 resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
29 scratch::ScratchBuffer,
30 snatch::SnatchGuard,
31 track::PendingTransition,
32};
33use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
34
35use crate::id::{BlasId, TlasId};
36
/// Per-geometry staging record produced by `iter_blas` and consumed by
/// `iter_buffers`: the resolved input buffers, their pending usage
/// transitions, and the geometry description needed to emit the hal
/// triangle entry.
struct TriangleBufferStore {
    /// Vertex buffer for this triangle geometry.
    vertex_buffer: Arc<Buffer>,
    /// Pending transition to BLAS-input usage for the vertex buffer, if any.
    vertex_transition: Option<PendingTransition<BufferUses>>,
    /// Optional index buffer plus its pending usage transition.
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// Optional transform buffer plus its pending usage transition.
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// Size/offset/stride description of the geometry.
    geometry: BlasTriangleGeometryInfo,
    /// Set on the last geometry of a BLAS so `iter_buffers` can tell where
    /// one BLAS's geometries end and the next one's begin.
    ending_blas: Option<Arc<Blas>>,
}
45
/// A BLAS ready to be built: the hal entries describing its geometries and
/// its offset into the shared scratch buffer.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// Byte offset into the shared scratch buffer reserved for this build.
    scratch_buffer_offset: u64,
}
51
/// A TLAS ready to be built: the hal instance entries and its offset into
/// the shared scratch buffer. Always wrapped in [`TlasStore`], which adds
/// the staging-byte range; see that type for usage.
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// Byte offset into the shared scratch buffer reserved for this build.
    scratch_buffer_offset: u64,
}
57
/// A TLAS build plus the byte range its serialized instances occupy in the
/// shared instance staging vector (used to copy them into the TLAS's
/// instance buffer).
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    /// Range into `instance_buffer_staging_source` holding this TLAS's
    /// serialized instances.
    range: Range<usize>,
}
62
impl Global {
    /// Looks up a BLAS by id, returning an error for invalid/destroyed ids.
    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
        self.hub.blas_s.get(blas_id).get()
    }

    /// Looks up a TLAS by id, returning an error for invalid/destroyed ids.
    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
        self.hub.tlas_s.get(tlas_id).get()
    }

    /// Records that the given acceleration structures have been built
    /// externally (without recording dependency information for the TLASes),
    /// so later validation treats them as built.
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.with_buffer(
            crate::command::EncodingApi::Raw,
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::default();

                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    // No dependencies are recorded here: the caller asserts
                    // the TLAS is built, we don't know what it was built from.
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

    /// Resolves all ids in the BLAS/TLAS build descriptions to resource Arcs
    /// and records a `BuildAccelerationStructures` command on the encoder.
    /// Validation and hal command emission happen later, in
    /// [`build_acceleration_structures`].
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
            // Resolve every buffer/BLAS id in the build entries; the first
            // invalid id aborts the whole command.
            let blas = blas_iter
                .map(|blas_entry| {
                    let geometries = match blas_entry.geometries {
                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
                            let tri_geo = triangle_geometries
                                .map(|tg| {
                                    Ok(ArcBlasTriangleGeometry {
                                        size: tg.size.clone(),
                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
                                        index_buffer: tg
                                            .index_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        transform_buffer: tg
                                            .transform_buffer
                                            .map(|id| self.resolve_buffer_id(id))
                                            .transpose()?,
                                        first_vertex: tg.first_vertex,
                                        vertex_stride: tg.vertex_stride,
                                        first_index: tg.first_index,
                                        transform_buffer_offset: tg.transform_buffer_offset,
                                    })
                                })
                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
                            ArcBlasGeometries::TriangleGeometries(tri_geo)
                        }
                    };
                    Ok(ArcBlasBuildEntry {
                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
                        geometries,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            // Resolve every TLAS package; `None` instance slots are kept
            // as-is (they denote empty slots in the instance list).
            let tlas = tlas_iter
                .map(|tlas_package| {
                    let instances = tlas_package
                        .instances
                        .map(|instance| {
                            instance
                                .as_ref()
                                .map(|instance| {
                                    Ok(ArcTlasInstance {
                                        blas: self.resolve_blas_id(instance.blas_id)?,
                                        transform: *instance.transform,
                                        custom_data: instance.custom_data,
                                        mask: instance.mask,
                                    })
                                })
                                .transpose()
                        })
                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
                    Ok(ArcTlasPackage {
                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
                        instances,
                        lowest_unmodified: tlas_package.lowest_unmodified,
                    })
                })
                .collect::<Result<_, BuildAccelerationStructureError>>()?;

            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
        })
    }
}
191
/// Validates and encodes a batch of BLAS and TLAS builds.
///
/// Flow: validate BLAS entries (`iter_blas`), resolve their buffers and size
/// the BLAS scratch region (`iter_buffers`), then serialize TLAS instances
/// into a staging vector while sizing the TLAS scratch region. A single
/// scratch buffer sized for the larger of the two regions is shared (BLAS
/// builds complete before TLAS builds start). Finally the staged instances
/// are copied into each TLAS's instance buffer and the hal build commands
/// and barriers are recorded.
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::default();
    let mut buf_storage = Vec::new();
    // Validate BLAS geometries against creation-time sizes and stage their
    // buffers.
    iter_blas(
        blas.into_iter(),
        state.tracker,
        &mut build_command,
        &mut buf_storage,
    )?;

    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::new();
    // Check buffer usages/sizes, collect input barriers, and compute the
    // total BLAS scratch requirement.
    iter_buffers(
        state,
        &mut buf_storage,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
    )?;
    let mut tlas_lock_store = Vec::<(Option<OwnedTlasPackage<ArcReferences>>, Arc<Tlas>)>::new();

    for package in tlas.into_iter() {
        let tlas = package.tlas.clone();
        state.tracker.tlas_s.insert_single(tlas.clone());
        tlas_lock_store.push((Some(package), tlas))
    }

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::new();
    // All TLAS instances, serialized to the backend's byte layout, packed
    // back-to-back; each TlasStore remembers its sub-range.
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for (package, tlas) in &mut tlas_lock_store {
        let package = package.take().unwrap();

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            state.device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.into_iter().flatten() {
            // Custom index is limited to 24 bits.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // Vertex-return on the TLAS requires it on every referenced BLAS.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // Zero scratch means there is nothing to build.
    // NOTE(review): returning here drops `build_command`, so nothing is
    // pushed to `state.as_actions` on this path — confirm that no entries
    // needing build-tracking can reach here with zero scratch size.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Same-state transition: emitted by `build_blas` between the BLAS and
    // TLAS builds so the TLAS builds can safely reuse the scratch buffer.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::info!("only rebuild implemented")
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::new();

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    // Encode all BLAS builds (and compaction-size reads) with the required
    // barriers.
    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        // Upload the serialized instances through a staging buffer.
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // A TLAS with no instances has nothing to copy.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            // Deferred barrier back to TLAS-input usage, applied after all
            // copies below.
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Transition the instance buffer to COPY_DST, then copy this
                // TLAS's slice of the staging data into it.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        // Keep the staging buffer alive until the GPU is done with it.
        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
474
475impl CommandBufferMutable {
476 pub(crate) fn validate_acceleration_structure_actions(
477 &self,
478 snatch_guard: &SnatchGuard,
479 command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
480 ) -> Result<(), ValidateAsActionsError> {
481 profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
482 for action in &self.as_actions {
483 match action {
484 AsAction::Build(build) => {
485 let build_command_index = NonZeroU64::new(
486 command_index_guard.next_acceleration_structure_build_command_index,
487 )
488 .unwrap();
489
490 command_index_guard.next_acceleration_structure_build_command_index += 1;
491 for blas in build.blas_s_built.iter() {
492 let mut state_lock = blas.compacted_state.lock();
493 *state_lock = match *state_lock {
494 BlasCompactState::Compacted => {
495 unreachable!("Should be validated out in build.")
496 }
497 _ => BlasCompactState::Idle,
500 };
501 *blas.built_index.write() = Some(build_command_index);
502 }
503
504 for tlas_build in build.tlas_s_built.iter() {
505 for blas in &tlas_build.dependencies {
506 if blas.built_index.read().is_none() {
507 return Err(ValidateAsActionsError::UsedUnbuiltBlas(
508 blas.error_ident(),
509 tlas_build.tlas.error_ident(),
510 ));
511 }
512 }
513 *tlas_build.tlas.built_index.write() = Some(build_command_index);
514 tlas_build
515 .tlas
516 .dependencies
517 .write()
518 .clone_from(&tlas_build.dependencies)
519 }
520 }
521 AsAction::UseTlas(tlas) => {
522 let tlas_build_index = tlas.built_index.read();
523 let dependencies = tlas.dependencies.read();
524
525 if (*tlas_build_index).is_none() {
526 return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
527 }
528 for blas in dependencies.deref() {
529 let blas_build_index = *blas.built_index.read();
530 if blas_build_index.is_none() {
531 return Err(ValidateAsActionsError::UsedUnbuiltBlas(
532 tlas.error_ident(),
533 blas.error_ident(),
534 ));
535 }
536 if blas_build_index.unwrap() > tlas_build_index.unwrap() {
537 return Err(ValidateAsActionsError::BlasNewerThenTlas(
538 blas.error_ident(),
539 tlas.error_ident(),
540 ));
541 }
542 blas.try_raw(snatch_guard)?;
543 }
544 }
545 }
546 }
547 Ok(())
548 }
549}
550
551fn iter_blas(
553 blas_iter: impl Iterator<Item = OwnedBlasBuildEntry<ArcReferences>>,
554 tracker: &mut Tracker,
555 build_command: &mut AsBuild,
556 buf_storage: &mut Vec<TriangleBufferStore>,
557) -> Result<(), BuildAccelerationStructureError> {
558 let mut temp_buffer = Vec::new();
559 for entry in blas_iter {
560 let blas = &entry.blas;
561 tracker.blas_s.insert_single(blas.clone());
562
563 build_command.blas_s_built.push(blas.clone());
564
565 match entry.geometries {
566 ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
567 for (i, mesh) in triangle_geometries.into_iter().enumerate() {
568 let size_desc = match &blas.sizes {
569 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
570 };
571 if i >= size_desc.len() {
572 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
573 blas.error_ident(),
574 ));
575 }
576 let size_desc = &size_desc[i];
577
578 if size_desc.flags != mesh.size.flags {
579 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
580 blas.error_ident(),
581 size_desc.flags,
582 mesh.size.flags,
583 ));
584 }
585
586 if size_desc.vertex_count < mesh.size.vertex_count {
587 return Err(
588 BuildAccelerationStructureError::IncompatibleBlasVertexCount(
589 blas.error_ident(),
590 size_desc.vertex_count,
591 mesh.size.vertex_count,
592 ),
593 );
594 }
595
596 if size_desc.vertex_format != mesh.size.vertex_format {
597 return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
598 blas.error_ident(),
599 size_desc.vertex_format,
600 mesh.size.vertex_format,
601 ));
602 }
603
604 if size_desc
605 .vertex_format
606 .min_acceleration_structure_vertex_stride()
607 > mesh.vertex_stride
608 {
609 return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
610 blas.error_ident(),
611 size_desc
612 .vertex_format
613 .min_acceleration_structure_vertex_stride(),
614 mesh.vertex_stride,
615 ));
616 }
617
618 if mesh.vertex_stride
619 % size_desc
620 .vertex_format
621 .acceleration_structure_stride_alignment()
622 != 0
623 {
624 return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
625 blas.error_ident(),
626 size_desc
627 .vertex_format
628 .acceleration_structure_stride_alignment(),
629 mesh.vertex_stride,
630 ));
631 }
632
633 match (size_desc.index_count, mesh.size.index_count) {
634 (Some(_), None) | (None, Some(_)) => {
635 return Err(
636 BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
637 blas.error_ident(),
638 ),
639 )
640 }
641 (Some(create), Some(build)) if create < build => {
642 return Err(
643 BuildAccelerationStructureError::IncompatibleBlasIndexCount(
644 blas.error_ident(),
645 create,
646 build,
647 ),
648 )
649 }
650 _ => {}
651 }
652
653 if size_desc.index_format != mesh.size.index_format {
654 return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
655 blas.error_ident(),
656 size_desc.index_format,
657 mesh.size.index_format,
658 ));
659 }
660
661 if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
662 return Err(BuildAccelerationStructureError::MissingIndexBuffer(
663 blas.error_ident(),
664 ));
665 }
666 let vertex_buffer = mesh.vertex_buffer.clone();
667 let vertex_pending = tracker.buffers.set_single(
668 &vertex_buffer,
669 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
670 );
671 let index_data = if let Some(index_buffer) = mesh.index_buffer {
672 if mesh.first_index.is_none()
673 || mesh.size.index_count.is_none()
674 || mesh.size.index_count.is_none()
675 {
676 return Err(BuildAccelerationStructureError::MissingAssociatedData(
677 index_buffer.error_ident(),
678 ));
679 }
680 let data = tracker.buffers.set_single(
681 &index_buffer,
682 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
683 );
684 Some((index_buffer, data))
685 } else {
686 None
687 };
688 let transform_data = if let Some(transform_buffer) = mesh.transform_buffer {
689 if !blas
690 .flags
691 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
692 {
693 return Err(BuildAccelerationStructureError::UseTransformMissing(
694 blas.error_ident(),
695 ));
696 }
697 if mesh.transform_buffer_offset.is_none() {
698 return Err(BuildAccelerationStructureError::MissingAssociatedData(
699 transform_buffer.error_ident(),
700 ));
701 }
702 let data = tracker.buffers.set_single(
703 &transform_buffer,
704 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
705 );
706 Some((transform_buffer, data))
707 } else {
708 if blas
709 .flags
710 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
711 {
712 return Err(BuildAccelerationStructureError::TransformMissing(
713 blas.error_ident(),
714 ));
715 }
716 None
717 };
718 temp_buffer.push(TriangleBufferStore {
719 vertex_buffer,
720 vertex_transition: vertex_pending,
721 index_buffer_transition: index_data,
722 transform_buffer_transition: transform_data,
723 geometry: BlasTriangleGeometryInfo {
724 size: mesh.size,
725 first_vertex: mesh.first_vertex,
726 vertex_stride: mesh.vertex_stride,
727 first_index: mesh.first_index,
728 transform_buffer_offset: mesh.transform_buffer_offset,
729 },
730 ending_blas: None,
731 });
732 }
733
734 if let Some(last) = temp_buffer.last_mut() {
735 last.ending_blas = Some(blas.clone());
736 buf_storage.append(&mut temp_buffer);
737 }
738 }
739 }
740 }
741 Ok(())
742}
743
/// Consumes the staged geometries from `iter_blas`: checks buffer usages,
/// sizes, and alignments, collects pending usage transitions into
/// `input_barriers`, records memory-init actions, accumulates the total BLAS
/// scratch requirement, and groups the hal triangle entries per BLAS into
/// `blas_storage`.
///
/// `'snatch_guard: 'buffers` lets the raw-buffer references obtained through
/// the snatch guard live as long as the borrow of `buf_storage`.
fn iter_buffers<'snatch_guard: 'buffers, 'buffers>(
    state: &mut EncodingState<'snatch_guard, '_>,
    buf_storage: &'buffers mut Vec<TriangleBufferStore>,
    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'buffers>>,
) -> Result<(), BuildAccelerationStructureError> {
    // Triangle entries accumulated for the current BLAS; flushed into
    // `blas_storage` when its last geometry (`ending_blas`) is reached.
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_raw = buf.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
            let vertex_buffer = &buf.vertex_buffer;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            // Queue the pending transition to BLAS-input usage, if any.
            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(buf.vertex_buffer.as_ref(), state.snatch_guard))
            {
                input_barriers.push(barrier);
            }
            // The buffer must cover (first_vertex + vertex_count) strides.
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            // The vertex range read by the build must be initialized.
            state.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    vertex_buffer,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            // first_index/index_count/index_format presence was validated in
            // `iter_blas`, so the unwraps below are safe for staged entries.
            let index_raw = index_buffer.try_raw(state.snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, state.snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            // Triangle lists need a multiple of 3 indices.
            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            state.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, state.snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            // A 3x4 row-major transform matrix occupies 48 bytes.
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            state.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        // Last geometry of a BLAS: reserve its scratch region and flush the
        // accumulated entries as one BlasStore.
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                state.device.alignments.ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}
921
922fn map_blas<'a>(
923 storage: &'a BlasStore<'_>,
924 scratch_buffer: &'a dyn hal::DynBuffer,
925 snatch_guard: &'a SnatchGuard,
926 blases_compactable: &mut Vec<(
927 &'a dyn hal::DynBuffer,
928 &'a dyn hal::DynAccelerationStructure,
929 )>,
930) -> Result<
931 hal::BuildAccelerationStructureDescriptor<
932 'a,
933 dyn hal::DynBuffer,
934 dyn hal::DynAccelerationStructure,
935 >,
936 BuildAccelerationStructureError,
937> {
938 let BlasStore {
939 blas,
940 entries,
941 scratch_buffer_offset,
942 } = storage;
943 if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
944 log::info!("only rebuild implemented")
945 }
946 let raw = blas.try_raw(snatch_guard)?;
947
948 let state_lock = blas.compacted_state.lock();
949 if let BlasCompactState::Compacted = *state_lock {
950 return Err(BuildAccelerationStructureError::CompactedBlas(
951 blas.error_ident(),
952 ));
953 }
954
955 if blas
956 .flags
957 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
958 {
959 blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
960 }
961 Ok(hal::BuildAccelerationStructureDescriptor {
962 entries,
963 mode: hal::AccelerationStructureBuildMode::Build,
964 flags: blas.flags,
965 source_acceleration_structure: None,
966 destination_acceleration_structure: raw,
967 scratch_buffer,
968 scratch_buffer_offset: *scratch_buffer_offset,
969 })
970}
971
/// Encodes the BLAS build pass: input-buffer barriers, the BLAS builds
/// themselves, compaction-size readbacks, and the barriers that hand the
/// results over to the TLAS build pass (or to shaders).
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    // Transition all vertex/index/transform inputs to BLAS-input usage.
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    // When TLAS builds follow and reuse the scratch buffer, fence them off
    // from the BLAS builds with the (same-state) scratch barrier.
    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    // Read back the compacted size of every compaction-enabled BLAS into its
    // compaction buffer.
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    // Accumulate the usages covered by the final barrier, depending on which
    // passes were encoded above.
    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}
1043}