1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::{
3 cmp::max,
4 num::NonZeroU64,
5 ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11 command::CommandBufferMutable,
12 device::queue::TempResource,
13 global::Global,
14 hub::Hub,
15 id::CommandEncoderId,
16 init_tracker::MemoryInitKind,
17 ray_tracing::{
18 BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
19 TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
20 TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
21 },
22 resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
23 scratch::ScratchBuffer,
24 snatch::SnatchGuard,
25 track::PendingTransition,
26};
27use crate::{
28 command::CommandEncoder,
29 ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
30};
31use crate::{command::EncoderStateError, device::resource::CommandIndices};
32use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
33
34use crate::id::{BlasId, TlasId};
35
/// Per-geometry staging data produced by `iter_blas` and consumed by
/// `iter_buffers` when lowering triangle geometries into hal entries.
struct TriangleBufferStore<'a> {
    /// Vertex buffer backing this triangle geometry.
    vertex_buffer: Arc<Buffer>,
    /// Pending tracker transition of the vertex buffer to BLAS-input usage, if any.
    vertex_transition: Option<PendingTransition<BufferUses>>,
    /// Index buffer and its pending usage transition (`None` for non-indexed geometry).
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// Transform buffer and its pending usage transition (`None` when the BLAS
    /// does not use transforms).
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// The geometry description this entry was built from.
    geometry: BlasTriangleGeometry<'a>,
    /// Set only on the last geometry belonging to a BLAS, so `iter_buffers`
    /// knows where one BLAS's geometry run ends and can emit a `BlasStore`.
    ending_blas: Option<Arc<Blas>>,
}
44
/// A BLAS together with its lowered hal geometry entries and its offset into
/// the shared scratch buffer used for the build.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
50
/// A TLAS with its hal instance entries and scratch-buffer offset. "Unsafe"
/// because it carries no record of which byte range of the instance staging
/// data belongs to it (see `TlasStore` for the safe pairing).
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
56
/// A `UnsafeTlasStore` paired with the byte range of the instance staging
/// buffer that holds this TLAS's serialized instances.
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}
61
impl Global {
    /// Records that the given BLASes and TLASes have already been built
    /// (externally to this encoder), without encoding any build commands.
    ///
    /// Only pushes an [`AsAction::Build`] bookkeeping entry so submission-time
    /// validation treats these acceleration structures as built; the TLAS
    /// entries are recorded with empty dependency lists.
    ///
    /// Errors (invalid device, missing `EXPERIMENTAL_RAY_QUERY` feature, bad
    /// ids) are deferred into the encoder's error state via `record_with`.
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.record_with(
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::default();

                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        // Marked-built TLASes carry no BLAS dependencies.
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

    /// Encodes acceleration-structure builds for the given BLAS entries and
    /// TLAS packages.
    ///
    /// The borrowed input iterators are first materialized into owned
    /// trace-friendly vectors (`trace_blas` / `trace_tlas`); borrowing
    /// iterators over those vectors then drive the actual build, while cloned
    /// copies are handed to `build_acceleration_structures` for optional
    /// trace recording.
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        // Materialize the caller's borrowing iterators into owned data.
        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.record_with(|cmd_buf_data| {
            // Re-borrow the owned vectors as the entry/package iterators the
            // build routine consumes.
            let blas_iter = trace_blas.iter().map(|blas_entry| {
                let geometries = match &blas_entry.geometries {
                    TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                        let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                            size: &tg.size,
                            vertex_buffer: tg.vertex_buffer,
                            index_buffer: tg.index_buffer,
                            transform_buffer: tg.transform_buffer,
                            first_vertex: tg.first_vertex,
                            vertex_stride: tg.vertex_stride,
                            first_index: tg.first_index,
                            transform_buffer_offset: tg.transform_buffer_offset,
                        });
                        BlasGeometries::TriangleGeometries(Box::new(iter))
                    }
                };
                BlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            });

            let tlas_iter = trace_tlas.iter().map(|tlas_package| {
                let instances = tlas_package.instances.iter().map(|instance| {
                    instance.as_ref().map(|instance| TlasInstance {
                        blas_id: instance.blas_id,
                        transform: &instance.transform,
                        custom_data: instance.custom_data,
                        mask: instance.mask,
                    })
                });
                TlasPackage {
                    tlas_id: tlas_package.tlas_id,
                    instances: Box::new(instances),
                    lowest_unmodified: tlas_package.lowest_unmodified,
                }
            });

            // The clones are required: `blas_iter`/`tlas_iter` borrow the
            // vectors, while the build routine takes the trace copies by value.
            build_acceleration_structures(
                cmd_buf_data,
                hub,
                &cmd_enc,
                trace_blas.clone(),
                trace_tlas.clone(),
                blas_iter,
                tlas_iter,
            )
        })
    }
}
215
/// Validates and encodes the BLAS and TLAS builds onto `cmd_enc`.
///
/// Encoding order is significant and must not be rearranged:
/// 1. validate BLAS entries and their buffers (`iter_blas` / `iter_buffers`);
/// 2. validate TLAS packages, serializing their instances into a CPU-side
///    staging vector;
/// 3. allocate one scratch buffer sized for the larger of the BLAS and TLAS
///    scratch requirements (both build phases reuse it);
/// 4. encode BLAS builds with input barriers (`build_blas`);
/// 5. upload instance data, barrier each TLAS instance buffer around its
///    copy, then encode the TLAS builds.
///
/// `trace_blas`/`trace_tlas` are consumed only when the `trace` feature
/// records the command.
pub(crate) fn build_acceleration_structures<'a>(
    cmd_buf_data: &'a mut CommandBufferMutable,
    hub: &'a Hub,
    cmd_enc: &'a Arc<CommandEncoder>,
    trace_blas: Vec<TraceBlasBuildEntry>,
    trace_tlas: Vec<TraceTlasPackage>,
    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
    tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
) -> Result<(), BuildAccelerationStructureError> {
    // Record the command for tracing before validation can fail.
    #[cfg(feature = "trace")]
    if let Some(ref mut list) = cmd_buf_data.trace_commands {
        list.push(crate::command::Command::BuildAccelerationStructures {
            blas: trace_blas,
            tlas: trace_tlas,
        });
    }
    #[cfg(not(feature = "trace"))]
    {
        let _ = trace_blas;
        let _ = trace_tlas;
    }

    let device = &cmd_enc.device;
    device.check_is_valid()?;
    device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::default();
    // Stage 1: validate each BLAS entry against its creation-time sizes and
    // collect per-geometry buffer/transition data.
    let mut buf_storage = Vec::new();
    iter_blas(
        blas_iter,
        cmd_buf_data,
        &mut build_command,
        &mut buf_storage,
        hub,
    )?;

    let snatch_guard = device.snatchable_lock.read();
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::new();
    // Stage 2: lower the staged geometries to hal entries, gathering input
    // barriers and the total BLAS scratch size.
    iter_buffers(
        &mut buf_storage,
        &snatch_guard,
        &mut input_barriers,
        cmd_buf_data,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        hub,
        device.alignments.ray_tracing_scratch_buffer_alignment,
    )?;
    let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

    // Resolve and track all TLASes up front; packages are kept alongside so
    // they can be taken by value below.
    for package in tlas_iter {
        let tlas = hub.tlas_s.get(package.tlas_id).get()?;

        cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

        tlas_lock_store.push((Some(package), tlas))
    }

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::new();
    // CPU-side serialization of every TLAS's instances; each TLAS remembers
    // its byte range into this vector (see `TlasStore::range`).
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for (package, tlas) in &mut tlas_lock_store {
        let package = package.take().unwrap();

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.flatten() {
            // Custom index is limited to 24 bits.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = hub.blas_s.get(instance.blas_id).get()?;

            cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

            // Serialize the instance into the backend's native layout.
            instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: *instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // Vertex return on the TLAS requires it on every referenced BLAS.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // One scratch buffer serves both phases; BLAS and TLAS builds are
    // separated by a barrier, so the max of the two sizes suffices.
    // Zero total size means there is nothing to build.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

    // Self-transition: orders TLAS scratch use after BLAS scratch use.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::info!("only rebuild implemented")
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let cmd_buf_raw = cmd_buf_data.encoder.open()?;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::new();

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            &snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    // Encode BLAS builds, their input barriers, and compaction-size reads.
    build_blas(
        cmd_buf_raw,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        // Upload the serialized instances through a staging buffer.
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // TLASes with no instances have an empty range and need no copy.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            // Deferred barrier back to TLAS-input usage, applied after all
            // copies below.
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Transition into COPY_DST, copy this TLAS's slice of the
                // staging data into its instance buffer.
                cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                cmd_buf_raw.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            cmd_buf_raw.transition_buffers(&instance_buffer_barriers);

            cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

            // Make the built TLASes visible to shaders.
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        // Keep the staging buffer alive until the commands have executed.
        if let Some(staging_buffer) = staging_buffer {
            cmd_buf_data
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    // Keep the scratch buffer alive until the commands have executed.
    cmd_buf_data
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    cmd_buf_data.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
525
526impl CommandBufferMutable {
527 pub(crate) fn validate_acceleration_structure_actions(
528 &self,
529 snatch_guard: &SnatchGuard,
530 command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
531 ) -> Result<(), ValidateAsActionsError> {
532 profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
533 for action in &self.as_actions {
534 match action {
535 AsAction::Build(build) => {
536 let build_command_index = NonZeroU64::new(
537 command_index_guard.next_acceleration_structure_build_command_index,
538 )
539 .unwrap();
540
541 command_index_guard.next_acceleration_structure_build_command_index += 1;
542 for blas in build.blas_s_built.iter() {
543 let mut state_lock = blas.compacted_state.lock();
544 *state_lock = match *state_lock {
545 BlasCompactState::Compacted => {
546 unreachable!("Should be validated out in build.")
547 }
548 _ => BlasCompactState::Idle,
551 };
552 *blas.built_index.write() = Some(build_command_index);
553 }
554
555 for tlas_build in build.tlas_s_built.iter() {
556 for blas in &tlas_build.dependencies {
557 if blas.built_index.read().is_none() {
558 return Err(ValidateAsActionsError::UsedUnbuiltBlas(
559 blas.error_ident(),
560 tlas_build.tlas.error_ident(),
561 ));
562 }
563 }
564 *tlas_build.tlas.built_index.write() = Some(build_command_index);
565 tlas_build
566 .tlas
567 .dependencies
568 .write()
569 .clone_from(&tlas_build.dependencies)
570 }
571 }
572 AsAction::UseTlas(tlas) => {
573 let tlas_build_index = tlas.built_index.read();
574 let dependencies = tlas.dependencies.read();
575
576 if (*tlas_build_index).is_none() {
577 return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
578 }
579 for blas in dependencies.deref() {
580 let blas_build_index = *blas.built_index.read();
581 if blas_build_index.is_none() {
582 return Err(ValidateAsActionsError::UsedUnbuiltBlas(
583 tlas.error_ident(),
584 blas.error_ident(),
585 ));
586 }
587 if blas_build_index.unwrap() > tlas_build_index.unwrap() {
588 return Err(ValidateAsActionsError::BlasNewerThenTlas(
589 blas.error_ident(),
590 tlas.error_ident(),
591 ));
592 }
593 blas.try_raw(snatch_guard)?;
594 }
595 }
596 }
597 }
598 Ok(())
599 }
600}
601
602fn iter_blas<'a>(
604 blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
605 cmd_buf_data: &mut CommandBufferMutable,
606 build_command: &mut AsBuild,
607 buf_storage: &mut Vec<TriangleBufferStore<'a>>,
608 hub: &Hub,
609) -> Result<(), BuildAccelerationStructureError> {
610 let mut temp_buffer = Vec::new();
611 for entry in blas_iter {
612 let blas = hub.blas_s.get(entry.blas_id).get()?;
613 cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
614
615 build_command.blas_s_built.push(blas.clone());
616
617 match entry.geometries {
618 BlasGeometries::TriangleGeometries(triangle_geometries) => {
619 for (i, mesh) in triangle_geometries.enumerate() {
620 let size_desc = match &blas.sizes {
621 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
622 };
623 if i >= size_desc.len() {
624 return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
625 blas.error_ident(),
626 ));
627 }
628 let size_desc = &size_desc[i];
629
630 if size_desc.flags != mesh.size.flags {
631 return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
632 blas.error_ident(),
633 size_desc.flags,
634 mesh.size.flags,
635 ));
636 }
637
638 if size_desc.vertex_count < mesh.size.vertex_count {
639 return Err(
640 BuildAccelerationStructureError::IncompatibleBlasVertexCount(
641 blas.error_ident(),
642 size_desc.vertex_count,
643 mesh.size.vertex_count,
644 ),
645 );
646 }
647
648 if size_desc.vertex_format != mesh.size.vertex_format {
649 return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
650 blas.error_ident(),
651 size_desc.vertex_format,
652 mesh.size.vertex_format,
653 ));
654 }
655
656 if size_desc
657 .vertex_format
658 .min_acceleration_structure_vertex_stride()
659 > mesh.vertex_stride
660 {
661 return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
662 blas.error_ident(),
663 size_desc
664 .vertex_format
665 .min_acceleration_structure_vertex_stride(),
666 mesh.vertex_stride,
667 ));
668 }
669
670 if mesh.vertex_stride
671 % size_desc
672 .vertex_format
673 .acceleration_structure_stride_alignment()
674 != 0
675 {
676 return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
677 blas.error_ident(),
678 size_desc
679 .vertex_format
680 .acceleration_structure_stride_alignment(),
681 mesh.vertex_stride,
682 ));
683 }
684
685 match (size_desc.index_count, mesh.size.index_count) {
686 (Some(_), None) | (None, Some(_)) => {
687 return Err(
688 BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
689 blas.error_ident(),
690 ),
691 )
692 }
693 (Some(create), Some(build)) if create < build => {
694 return Err(
695 BuildAccelerationStructureError::IncompatibleBlasIndexCount(
696 blas.error_ident(),
697 create,
698 build,
699 ),
700 )
701 }
702 _ => {}
703 }
704
705 if size_desc.index_format != mesh.size.index_format {
706 return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
707 blas.error_ident(),
708 size_desc.index_format,
709 mesh.size.index_format,
710 ));
711 }
712
713 if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
714 return Err(BuildAccelerationStructureError::MissingIndexBuffer(
715 blas.error_ident(),
716 ));
717 }
718 let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
719 let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
720 &vertex_buffer,
721 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
722 );
723 let index_data = if let Some(index_id) = mesh.index_buffer {
724 let index_buffer = hub.buffers.get(index_id).get()?;
725 if mesh.first_index.is_none()
726 || mesh.size.index_count.is_none()
727 || mesh.size.index_count.is_none()
728 {
729 return Err(BuildAccelerationStructureError::MissingAssociatedData(
730 index_buffer.error_ident(),
731 ));
732 }
733 let data = cmd_buf_data.trackers.buffers.set_single(
734 &index_buffer,
735 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
736 );
737 Some((index_buffer, data))
738 } else {
739 None
740 };
741 let transform_data = if let Some(transform_id) = mesh.transform_buffer {
742 if !blas
743 .flags
744 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
745 {
746 return Err(BuildAccelerationStructureError::UseTransformMissing(
747 blas.error_ident(),
748 ));
749 }
750 let transform_buffer = hub.buffers.get(transform_id).get()?;
751 if mesh.transform_buffer_offset.is_none() {
752 return Err(BuildAccelerationStructureError::MissingAssociatedData(
753 transform_buffer.error_ident(),
754 ));
755 }
756 let data = cmd_buf_data.trackers.buffers.set_single(
757 &transform_buffer,
758 BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
759 );
760 Some((transform_buffer, data))
761 } else {
762 if blas
763 .flags
764 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
765 {
766 return Err(BuildAccelerationStructureError::TransformMissing(
767 blas.error_ident(),
768 ));
769 }
770 None
771 };
772 temp_buffer.push(TriangleBufferStore {
773 vertex_buffer,
774 vertex_transition: vertex_pending,
775 index_buffer_transition: index_data,
776 transform_buffer_transition: transform_data,
777 geometry: mesh,
778 ending_blas: None,
779 });
780 }
781
782 if let Some(last) = temp_buffer.last_mut() {
783 last.ending_blas = Some(blas);
784 buf_storage.append(&mut temp_buffer);
785 }
786 }
787 }
788 }
789 Ok(())
790}
791
/// Lowers the staged geometries from `iter_blas` into hal triangle entries,
/// collecting input barriers, memory-init actions, and the total BLAS scratch
/// size. Emits one `BlasStore` per BLAS (detected via `ending_blas`).
///
/// The `.unwrap()`s on `first_index`, `index_count`, `index_format` and
/// `transform_buffer_offset` rely on the presence checks `iter_blas`
/// performed while staging each geometry.
fn iter_buffers<'a, 'b>(
    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
    snatch_guard: &'a SnatchGuard,
    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
    cmd_buf_data: &mut CommandBufferMutable,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'a>>,
    hub: &Hub,
    ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> {
    // Accumulates the triangle entries of the current BLAS until its
    // `ending_blas` marker is reached.
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_buffer = buf.vertex_buffer.as_ref();
            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            // The buffer must cover first_vertex + vertex_count vertices.
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            cmd_buf_data.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    &hub.buffers.get(mesh.vertex_buffer).get()?,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            // Triangle lists require a multiple of 3 indices.
            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            cmd_buf_data.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            // 48 bytes: the transform data read at `offset` (a 3x4 f32
            // matrix) must fit in the buffer.
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            cmd_buf_data.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        // End of this BLAS's geometry run: allocate its scratch region and
        // emit the accumulated entries.
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}
968
969fn map_blas<'a>(
970 storage: &'a BlasStore<'_>,
971 scratch_buffer: &'a dyn hal::DynBuffer,
972 snatch_guard: &'a SnatchGuard,
973 blases_compactable: &mut Vec<(
974 &'a dyn hal::DynBuffer,
975 &'a dyn hal::DynAccelerationStructure,
976 )>,
977) -> Result<
978 hal::BuildAccelerationStructureDescriptor<
979 'a,
980 dyn hal::DynBuffer,
981 dyn hal::DynAccelerationStructure,
982 >,
983 BuildAccelerationStructureError,
984> {
985 let BlasStore {
986 blas,
987 entries,
988 scratch_buffer_offset,
989 } = storage;
990 if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
991 log::info!("only rebuild implemented")
992 }
993 let raw = blas.try_raw(snatch_guard)?;
994
995 let state_lock = blas.compacted_state.lock();
996 if let BlasCompactState::Compacted = *state_lock {
997 return Err(BuildAccelerationStructureError::CompactedBlas(
998 blas.error_ident(),
999 ));
1000 }
1001
1002 if blas
1003 .flags
1004 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1005 {
1006 blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1007 }
1008 Ok(hal::BuildAccelerationStructureDescriptor {
1009 entries,
1010 mode: hal::AccelerationStructureBuildMode::Build,
1011 flags: blas.flags,
1012 source_acceleration_structure: None,
1013 destination_acceleration_structure: raw,
1014 scratch_buffer,
1015 scratch_buffer_offset: *scratch_buffer_offset,
1016 })
1017}
1018
/// Encodes the BLAS build phase: input-buffer barriers, the builds
/// themselves, compaction-size readbacks, and the acceleration-structure
/// barrier that hands off to the TLAS phase (or to shader use).
///
/// `scratch_buffer_barrier` is only placed when both BLAS and TLAS builds
/// are present, ordering their shared use of the scratch buffer.
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    // Transition vertex/index/transform buffers to BLAS-input usage.
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    // Order TLAS scratch use after BLAS scratch use on the shared buffer.
    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    // Accumulate the usages for the final barrier while encoding
    // compaction-size readbacks for compactable BLASes.
    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            // Self-transition serializing successive query-buffer writes.
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        // Built BLASes become inputs for the TLAS builds.
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}