wgpu_core/command/
ray_tracing.rs

1use alloc::{sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11    command::encoder::EncodingState,
12    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
13    resource::InvalidResourceError,
14};
15use crate::{command::EncoderStateError, device::resource::CommandIndices};
16use crate::{
17    command::{ArcCommand, ArcReferences, CommandBufferMutable},
18    device::queue::TempResource,
19    global::Global,
20    id::CommandEncoderId,
21    init_tracker::MemoryInitKind,
22    ray_tracing::{
23        ArcBlasAabbGeometry, ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry,
24        ArcTlasInstance, ArcTlasPackage, BlasBuildEntry, BlasGeometries,
25        BuildAccelerationStructureError, OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
26    },
27    resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
28    scratch::ScratchBuffer,
29    snatch::SnatchGuard,
30};
31use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
32
33use crate::id::{BlasId, TlasId};
34
/// Pending BLAS build gathered during validation: the resource itself, the
/// HAL geometry entries to build it from, and this build's byte offset into
/// the shared scratch buffer.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
40
/// Pending TLAS build without its staging-byte range. Presumably "unsafe"
/// because `entries` points at the TLAS's instance buffer before the staged
/// instance data has been copied into it — TODO confirm naming intent.
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
46
/// Pending TLAS build plus the byte range its serialized instances occupy
/// within the shared instance staging source (`instance_buffer_staging_source`).
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}
51
52impl Global {
53    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
54        self.hub.blas_s.get(blas_id).get()
55    }
56
57    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
58        self.hub.tlas_s.get(tlas_id).get()
59    }
60
61    pub fn command_encoder_mark_acceleration_structures_built(
62        &self,
63        command_encoder_id: CommandEncoderId,
64        blas_ids: &[BlasId],
65        tlas_ids: &[TlasId],
66    ) -> Result<(), EncoderStateError> {
67        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
68
69        let hub = &self.hub;
70
71        let cmd_enc = hub.command_encoders.get(command_encoder_id);
72
73        let mut cmd_buf_data = cmd_enc.data.lock();
74        cmd_buf_data.with_buffer(
75            crate::command::EncodingApi::Raw,
76            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
77                let device = &cmd_enc.device;
78                device.check_is_valid()?;
79                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
80
81                let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());
82
83                for blas in blas_ids {
84                    let blas = hub.blas_s.get(*blas).get()?;
85                    build_command.blas_s_built.push(blas);
86                }
87
88                for tlas in tlas_ids {
89                    let tlas = hub.tlas_s.get(*tlas).get()?;
90                    build_command.tlas_s_built.push(TlasBuild {
91                        tlas,
92                        dependencies: Vec::new(),
93                    });
94                }
95
96                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
97                Ok(())
98            },
99        )
100    }
101
102    pub fn command_encoder_build_acceleration_structures<'a>(
103        &self,
104        command_encoder_id: CommandEncoderId,
105        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
106        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
107    ) -> Result<(), EncoderStateError> {
108        profiling::scope!("CommandEncoder::build_acceleration_structures");
109
110        let hub = &self.hub;
111
112        let cmd_enc = hub.command_encoders.get(command_encoder_id);
113        let mut cmd_buf_data = cmd_enc.data.lock();
114
115        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
116            let blas = blas_iter
117                .map(|blas_entry| {
118                    let geometries = match blas_entry.geometries {
119                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
120                            let tri_geo = triangle_geometries
121                                .map(|tg| {
122                                    Ok(ArcBlasTriangleGeometry {
123                                        size: tg.size.clone(),
124                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
125                                        index_buffer: tg
126                                            .index_buffer
127                                            .map(|id| self.resolve_buffer_id(id))
128                                            .transpose()?,
129                                        transform_buffer: tg
130                                            .transform_buffer
131                                            .map(|id| self.resolve_buffer_id(id))
132                                            .transpose()?,
133                                        first_vertex: tg.first_vertex,
134                                        vertex_stride: tg.vertex_stride,
135                                        first_index: tg.first_index,
136                                        transform_buffer_offset: tg.transform_buffer_offset,
137                                    })
138                                })
139                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
140                            ArcBlasGeometries::TriangleGeometries(tri_geo)
141                        }
142                        BlasGeometries::AabbGeometries(aabb_geometries) => {
143                            let aabb_geo = aabb_geometries
144                                .map(|ag| {
145                                    Ok(ArcBlasAabbGeometry {
146                                        size: ag.size.clone(),
147                                        stride: ag.stride,
148                                        aabb_buffer: self.resolve_buffer_id(ag.aabb_buffer)?,
149                                        primitive_offset: ag.primitive_offset,
150                                    })
151                                })
152                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
153                            ArcBlasGeometries::AabbGeometries(aabb_geo)
154                        }
155                    };
156                    Ok(ArcBlasBuildEntry {
157                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
158                        geometries,
159                    })
160                })
161                .collect::<Result<_, BuildAccelerationStructureError>>()?;
162
163            let tlas = tlas_iter
164                .map(|tlas_package| {
165                    let instances = tlas_package
166                        .instances
167                        .map(|instance| {
168                            instance
169                                .as_ref()
170                                .map(|instance| {
171                                    Ok(ArcTlasInstance {
172                                        blas: self.resolve_blas_id(instance.blas_id)?,
173                                        transform: *instance.transform,
174                                        custom_data: instance.custom_data,
175                                        mask: instance.mask,
176                                    })
177                                })
178                                .transpose()
179                        })
180                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
181                    Ok(ArcTlasPackage {
182                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
183                        instances,
184                        lowest_unmodified: tlas_package.lowest_unmodified,
185                    })
186                })
187                .collect::<Result<_, BuildAccelerationStructureError>>()?;
188
189            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
190        })
191    }
192}
193
/// Validates a `BuildAccelerationStructures` command and records the actual
/// BLAS/TLAS builds onto the raw HAL encoder.
///
/// High-level flow:
/// 1. validate BLAS geometry and gather HAL entries/barriers (`iter_blas`);
/// 2. validate each TLAS package, serializing its instances into one CPU-side
///    staging byte vector;
/// 3. allocate a single scratch buffer shared by all builds;
/// 4. record the BLAS builds, then (if any) copy the staged instance data
///    into each TLAS's instance buffer and record the TLAS builds, with the
///    required buffer / acceleration-structure barriers in between;
/// 5. park the scratch and staging buffers in `temp_resources` so they stay
///    alive until the commands finish, and queue the `AsAction::Build` for
///    submission-time validation.
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    // Validates each BLAS entry against its creation-time size descriptors and
    // accumulates the BLAS portion of the scratch-buffer size.
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size: u64 = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    // All TLAS instances, serialized back-to-back; each TLAS remembers its
    // byte range into this vector (see `TlasStore::range`).
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        // Each TLAS gets its own aligned slice of the scratch buffer.
        // `saturating_add` means overflow shows up as u64::MAX, checked below.
        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size = scratch_buffer_tlas_size.saturating_add(align_to(
            tlas.size_info.build_scratch_size,
            u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
        ));

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        // `flatten` skips empty (None) instance slots.
        for instance in package.instances.iter().flatten() {
            // The instance custom index must fit in 24 bits.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            // Serialize the instance into the backend-specific byte layout.
            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // A TLAS allowing vertex return requires every referenced BLAS to
            // allow it too.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // The BLAS and TLAS phases are recorded sequentially, so one scratch
    // buffer sized for the larger phase serves both.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        // if the size is zero there is nothing to build
        return Ok(());
    };

    // u64::MAX is the saturation sentinel from the size accumulation above.
    if scratch_size.get() == u64::MAX {
        return Err(crate::device::DeviceError::OutOfMemory.into());
    }

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Scratch-to-scratch transition, consumed by `build_blas` between phases.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        // Only full rebuilds are supported; PreferUpdate degrades with a warning.
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    // Records input barriers + BLAS builds (and, per its arguments, the
    // scratch barrier needed before the TLAS phase).
    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        // Upload the serialized instances through a staging buffer; empty if
        // every TLAS in this build has zero instances.
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                // Staging buffer was CPU-written; make it readable as a copy source.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        // Barriers transitioning each instance buffer back to TLAS input after
        // its copy; applied in one batch below.
        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // Zero-length range means this TLAS has no instances: nothing to copy.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Make the instance buffer writable, then copy this TLAS's
                // slice of the staging data into it.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            // Flip all instance buffers to TLAS-input, build, then expose the
            // results to shaders.
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        // Keep the staging buffer alive until the GPU is done with the copy.
        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
467
impl CommandBufferMutable {
    /// Submission-time validation of all recorded acceleration-structure
    /// actions, in recording order.
    ///
    /// For each `Build`: allocates a fresh monotonically increasing build
    /// index, resets BLAS compaction state, and stamps the index on every
    /// built BLAS/TLAS. For each `UseTlas`: checks the TLAS and all of its
    /// BLAS dependencies were built, and that no dependency was built *after*
    /// the TLAS (which would make the TLAS stale).
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    // Build indices start above zero, so the NonZeroU64
                    // conversion cannot fail here.
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            // Reset the compacted state to idle. This means any prepares, before mapping their
                            // internal buffer, will terminate.
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                // NOTE(review): arguments here are
                                // (blas, tlas), but the `UsedUnbuiltBlas` in
                                // the `UseTlas` arm below passes (tlas, blas).
                                // One of the two orders is likely wrong —
                                // confirm against the variant's declaration.
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        // Record when this TLAS was built and snapshot its
                        // dependency list for later `UseTlas` validation.
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        // A BLAS rebuilt after the TLAS leaves the TLAS stale.
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        // Fail if the BLAS has been destroyed in the meantime.
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }

    /// Collects the raw acceleration structures that this command buffer's
    /// recorded work depends on (BLASes of built TLASes, and dependencies of
    /// used TLASes) and hands them to the HAL encoder.
    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
        // Take all `UseTlas` dependency read locks up front so the borrowed
        // raw handles pushed into `dependencies` below stay valid while the
        // final `set_acceleration_structure_dependencies` call runs.
        let tlas_dependencies_locks: Vec<_> = self
            .as_actions
            .iter()
            .filter_map(|action| {
                if let AsAction::UseTlas(tlas) = action {
                    Some(tlas.dependencies.read())
                } else {
                    None
                }
            })
            .collect();
        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
        let mut dependencies = Vec::new();
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    for tlas_build in build.tlas_s_built.iter() {
                        for dependency in &tlas_build.dependencies {
                            if let Some(dependency) = dependency.raw(snatch_guard) {
                                dependencies.push(dependency);
                            }
                        }
                    }
                }
                AsAction::UseTlas(_tlas) => {
                    // Exactly one lock was collected above per `UseTlas`
                    // action, in order, so this `next()` cannot fail.
                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap();
                    for dependency in tlas_dependencies.iter() {
                        if let Some(dependency) = dependency.raw(snatch_guard) {
                            dependencies.push(dependency);
                        }
                    }
                }
            }
        }
        if !dependencies.is_empty() {
            unsafe {
                self.encoder
                    .raw
                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
            }
        }
    }
}
588
/// Iterates over the BLAS build entries and their geometries, validating each one and pushing the required buffers and barriers into the storage vectors.
590fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
591    blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
592    build_command: &mut AsBuild,
593    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
594    scratch_buffer_blas_size: &mut u64,
595    blas_storage: &mut Vec<BlasStore<'buffers>>,
596    state: &mut EncodingState<'snatch_guard, '_>,
597) -> Result<(), BuildAccelerationStructureError> {
598    for entry in blas_iter {
599        let blas = &entry.blas;
600        state.tracker.blas_s.insert_single(blas.clone());
601
602        build_command.blas_s_built.push(blas.clone());
603
604        match &entry.geometries {
605            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
606                let mut triangle_entries =
607                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
608
609                for (i, mesh) in triangle_geometries.iter().enumerate() {
610                    let size_desc = match &blas.sizes {
611                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
612                        _ => {
613                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
614                                blas.error_ident(),
615                            ));
616                        }
617                    };
618                    if i >= size_desc.len() {
619                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
620                            blas.error_ident(),
621                        ));
622                    }
623                    let size_desc = &size_desc[i];
624
625                    if size_desc.flags != mesh.size.flags {
626                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
627                            blas.error_ident(),
628                            size_desc.flags,
629                            mesh.size.flags,
630                        ));
631                    }
632
633                    if size_desc.vertex_count < mesh.size.vertex_count {
634                        return Err(
635                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
636                                blas.error_ident(),
637                                size_desc.vertex_count,
638                                mesh.size.vertex_count,
639                            ),
640                        );
641                    }
642
643                    if size_desc.vertex_format != mesh.size.vertex_format {
644                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
645                            blas.error_ident(),
646                            size_desc.vertex_format,
647                            mesh.size.vertex_format,
648                        ));
649                    }
650
651                    if size_desc
652                        .vertex_format
653                        .min_acceleration_structure_vertex_stride()
654                        > mesh.vertex_stride
655                    {
656                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
657                            blas.error_ident(),
658                            size_desc
659                                .vertex_format
660                                .min_acceleration_structure_vertex_stride(),
661                            mesh.vertex_stride,
662                        ));
663                    }
664
665                    if mesh.vertex_stride
666                        % size_desc
667                            .vertex_format
668                            .acceleration_structure_stride_alignment()
669                        != 0
670                    {
671                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
672                            blas.error_ident(),
673                            size_desc
674                                .vertex_format
675                                .acceleration_structure_stride_alignment(),
676                            mesh.vertex_stride,
677                        ));
678                    }
679
680                    match (size_desc.index_count, mesh.size.index_count) {
681                        (Some(_), None) | (None, Some(_)) => {
682                            return Err(
683                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
684                                    blas.error_ident(),
685                                ),
686                            )
687                        }
688                        (Some(create), Some(build)) if create < build => {
689                            return Err(
690                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
691                                    blas.error_ident(),
692                                    create,
693                                    build,
694                                ),
695                            )
696                        }
697                        _ => {}
698                    }
699
700                    if size_desc.index_format != mesh.size.index_format {
701                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
702                            blas.error_ident(),
703                            size_desc.index_format,
704                            mesh.size.index_format,
705                        ));
706                    }
707
708                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
709                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
710                            blas.error_ident(),
711                        ));
712                    }
713                    let vertex_buffer = mesh.vertex_buffer.clone();
714                    let vertex_pending = state.tracker.buffers.set_single(
715                        &vertex_buffer,
716                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
717                    );
718                    let vertex_buffer = {
719                        let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
720                        let vertex_buffer = &mesh.vertex_buffer;
721                        vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
722
723                        if let Some(barrier) = vertex_pending.map(|pending| {
724                            pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
725                        }) {
726                            input_barriers.push(barrier);
727                        }
728                        if u64::from(mesh.size.vertex_count)
729                            .checked_add(u64::from(mesh.first_vertex))
730                            .and_then(|end_vertex| end_vertex.checked_mul(mesh.vertex_stride))
731                            .is_none_or(|end| vertex_buffer.size < end)
732                        {
733                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
734                                buffer_ident: vertex_buffer.error_ident(),
735                                offset: u64::from(mesh.first_vertex)
736                                    .saturating_mul(mesh.vertex_stride),
737                                region_size: u64::from(mesh.size.vertex_count)
738                                    .saturating_mul(mesh.vertex_stride),
739                                buffer_size: vertex_buffer.size,
740                            });
741                        }
742                        let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
743                        state.buffer_memory_init_actions.extend(
744                            vertex_buffer.initialization_status.read().create_action(
745                                vertex_buffer,
746                                vertex_buffer_offset
747                                    ..(vertex_buffer_offset
748                                        + mesh.size.vertex_count as u64 * mesh.vertex_stride),
749                                MemoryInitKind::NeedsInitializedMemory,
750                            ),
751                        );
752                        vertex_raw
753                    };
754                    let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
755                        if mesh.first_index.is_none()
756                            || mesh.size.index_count.is_none()
757                            || mesh.size.index_count.is_none()
758                        {
759                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
760                                index_buffer.error_ident(),
761                            ));
762                        }
763                        let index_pending = state.tracker.buffers.set_single(
764                            index_buffer,
765                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
766                        );
767                        let index_raw = index_buffer.try_raw(state.snatch_guard)?;
768                        index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
769
770                        if let Some(barrier) = index_pending.map(|pending| {
771                            pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
772                        }) {
773                            input_barriers.push(barrier);
774                        }
775                        let index_stride = mesh.size.index_format.unwrap().byte_size();
776                        // `hal::AccelerationStructureTriangleIndices` accepts only `u32` offset
777                        let Some(vertex_offset) =
778                            mesh.first_index.unwrap().checked_mul(index_stride)
779                        else {
780                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
781                                buffer_ident: index_buffer.error_ident(),
782                                offset: u64::from(mesh.first_index.unwrap())
783                                    .saturating_mul(u64::from(index_stride)),
784                                count: u64::from(mesh.first_index.unwrap()),
785                                stride: u64::from(index_stride),
786                            });
787                        };
788                        let Some(indexes_size) =
789                            mesh.size.index_count.unwrap().checked_mul(index_stride)
790                        else {
791                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
792                                buffer_ident: index_buffer.error_ident(),
793                                offset: u64::from(mesh.size.index_count.unwrap())
794                                    .saturating_mul(u64::from(index_stride)),
795                                count: u64::from(mesh.size.index_count.unwrap()),
796                                stride: u64::from(index_stride),
797                            });
798                        };
799
800                        if mesh.size.index_count.unwrap() % 3 != 0 {
801                            return Err(BuildAccelerationStructureError::InvalidIndexCount(
802                                index_buffer.error_ident(),
803                                mesh.size.index_count.unwrap(),
804                            ));
805                        }
806                        if index_buffer.size < u64::from(vertex_offset)
807                            || index_buffer.size - u64::from(vertex_offset)
808                                < u64::from(indexes_size)
809                        {
810                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
811                                buffer_ident: index_buffer.error_ident(),
812                                offset: u64::from(vertex_offset),
813                                region_size: u64::from(indexes_size),
814                                buffer_size: index_buffer.size,
815                            });
816                        }
817
818                        state.buffer_memory_init_actions.extend(
819                            index_buffer.initialization_status.read().create_action(
820                                index_buffer,
821                                u64::from(vertex_offset)
822                                    ..(u64::from(vertex_offset) + u64::from(indexes_size)),
823                                MemoryInitKind::NeedsInitializedMemory,
824                            ),
825                        );
826                        Some((index_raw, vertex_offset))
827                    } else {
828                        None
829                    };
830                    let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
831                    {
832                        if !blas
833                            .flags
834                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
835                        {
836                            return Err(BuildAccelerationStructureError::UseTransformMissing(
837                                blas.error_ident(),
838                            ));
839                        }
840                        if mesh.transform_buffer_offset.is_none() {
841                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
842                                transform_buffer.error_ident(),
843                            ));
844                        }
845                        let transform_pending = state.tracker.buffers.set_single(
846                            transform_buffer,
847                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
848                        );
849                        if mesh.transform_buffer_offset.is_none() {
850                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
851                                transform_buffer.error_ident(),
852                            ));
853                        }
854                        let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
855                        transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
856
857                        if let Some(barrier) = transform_pending.map(|pending| {
858                            pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
859                        }) {
860                            input_barriers.push(barrier);
861                        }
862
863                        // `hal::AccelerationStructureTriangleTransform` accepts only `u32` offset
864                        let Ok(offset) = u32::try_from(mesh.transform_buffer_offset.unwrap())
865                        else {
866                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
867                                buffer_ident: transform_buffer.error_ident(),
868                                offset: mesh.transform_buffer_offset.unwrap(),
869                                count: mesh.transform_buffer_offset.unwrap(),
870                                stride: 1,
871                            });
872                        };
873
874                        if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT as u32 != 0 {
875                            return Err(
876                                BuildAccelerationStructureError::UnalignedTransformBufferOffset(
877                                    transform_buffer.error_ident(),
878                                ),
879                            );
880                        }
881                        if transform_buffer.size < 48
882                            || transform_buffer.size - 48 < u64::from(offset)
883                        {
884                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
885                                buffer_ident: transform_buffer.error_ident(),
886                                region_size: 48,
887                                offset: u64::from(offset),
888                                buffer_size: transform_buffer.size,
889                            });
890                        }
891                        state.buffer_memory_init_actions.extend(
892                            transform_buffer.initialization_status.read().create_action(
893                                transform_buffer,
894                                u64::from(offset)..(u64::from(offset) + 48),
895                                MemoryInitKind::NeedsInitializedMemory,
896                            ),
897                        );
898                        Some((transform_raw, offset))
899                    } else {
900                        if blas
901                            .flags
902                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
903                        {
904                            return Err(BuildAccelerationStructureError::TransformMissing(
905                                blas.error_ident(),
906                            ));
907                        }
908                        None
909                    };
910
911                    let triangles = hal::AccelerationStructureTriangles {
912                        vertex_buffer: Some(vertex_buffer),
913                        vertex_format: mesh.size.vertex_format,
914                        first_vertex: mesh.first_vertex,
915                        vertex_count: mesh.size.vertex_count,
916                        vertex_stride: mesh.vertex_stride,
917                        indices: index_buffer.map(|(index_raw, vertex_offset)| {
918                            hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
919                                format: mesh.size.index_format.unwrap(),
920                                buffer: Some(index_raw),
921                                offset: vertex_offset,
922                                count: mesh.size.index_count.unwrap(),
923                            }
924                        }),
925                        transform: transform_buffer.map(|(transform_raw, offset)| {
926                            hal::AccelerationStructureTriangleTransform {
927                                buffer: transform_raw,
928                                offset,
929                            }
930                        }),
931                        flags: mesh.size.flags,
932                    };
933                    triangle_entries.push(triangles);
934                }
935
936                {
937                    let scratch_buffer_offset = *scratch_buffer_blas_size;
938                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
939                        blas.size_info.build_scratch_size,
940                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
941                    ));
942
943                    blas_storage.push(BlasStore {
944                        blas: blas.clone(),
945                        entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
946                        scratch_buffer_offset,
947                    });
948                }
949            }
950            ArcBlasGeometries::AabbGeometries(aabb_geometries) => {
951                let mut aabb_entries =
952                    Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::new();
953
954                for (i, aabb) in aabb_geometries.iter().enumerate() {
955                    let size_desc = match &blas.sizes {
956                        wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => descriptors,
957                        _ => {
958                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
959                                blas.error_ident(),
960                            ));
961                        }
962                    };
963                    if i >= size_desc.len() {
964                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
965                            blas.error_ident(),
966                        ));
967                    }
968                    let size_desc = &size_desc[i];
969
970                    if size_desc.flags != aabb.size.flags {
971                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
972                            blas.error_ident(),
973                            size_desc.flags,
974                            aabb.size.flags,
975                        ));
976                    }
977
978                    if size_desc.primitive_count < aabb.size.primitive_count {
979                        return Err(
980                            BuildAccelerationStructureError::IncompatibleBlasAabbPrimitiveCount(
981                                blas.error_ident(),
982                                size_desc.primitive_count,
983                                aabb.size.primitive_count,
984                            ),
985                        );
986                    }
987
988                    if aabb.primitive_offset % 8 != 0 {
989                        return Err(
990                            BuildAccelerationStructureError::UnalignedAabbPrimitiveOffset(
991                                blas.error_ident(),
992                            ),
993                        );
994                    }
995
996                    if aabb.stride < wgt::AABB_GEOMETRY_MIN_STRIDE || aabb.stride % 8 != 0 {
997                        return Err(BuildAccelerationStructureError::InvalidAabbStride(
998                            blas.error_ident(),
999                            aabb.stride,
1000                        ));
1001                    }
1002
1003                    let aabb_buffer = aabb.aabb_buffer.clone();
1004                    let aabb_pending = state.tracker.buffers.set_single(
1005                        &aabb_buffer,
1006                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
1007                    );
1008                    let aabb_raw = {
1009                        let aabb_raw = aabb.aabb_buffer.as_ref().try_raw(state.snatch_guard)?;
1010                        let aabb_buffer = &aabb.aabb_buffer;
1011                        aabb_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1012
1013                        if let Some(barrier) = aabb_pending.map(|pending| {
1014                            pending.into_hal(aabb_buffer.as_ref(), state.snatch_guard)
1015                        }) {
1016                            input_barriers.push(barrier);
1017                        }
1018
1019                        // `hal::AccelerationStructureAABBs` accepts only `u32` offset
1020                        let Some(aabb_size) = u32::try_from(aabb.stride)
1021                            .ok()
1022                            .and_then(|stride| aabb.size.primitive_count.checked_mul(stride))
1023                        else {
1024                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
1025                                buffer_ident: aabb_buffer.error_ident(),
1026                                offset: u64::from(aabb.size.primitive_count)
1027                                    .saturating_mul(aabb.stride),
1028                                count: u64::from(aabb.size.primitive_count),
1029                                stride: aabb.stride,
1030                            });
1031                        };
1032
1033                        if aabb_buffer.size < u64::from(aabb.primitive_offset)
1034                            || aabb_buffer.size - u64::from(aabb.primitive_offset)
1035                                < u64::from(aabb_size)
1036                        {
1037                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
1038                                buffer_ident: aabb_buffer.error_ident(),
1039                                region_size: u64::from(aabb_size),
1040                                offset: u64::from(aabb.primitive_offset),
1041                                buffer_size: aabb_buffer.size,
1042                            });
1043                        }
1044
1045                        state.buffer_memory_init_actions.extend(
1046                            aabb_buffer.initialization_status.read().create_action(
1047                                aabb_buffer,
1048                                u64::from(aabb.primitive_offset)
1049                                    ..u64::from(aabb.primitive_offset) + u64::from(aabb_size),
1050                                MemoryInitKind::NeedsInitializedMemory,
1051                            ),
1052                        );
1053                        aabb_raw
1054                    };
1055
1056                    aabb_entries.push(hal::AccelerationStructureAABBs {
1057                        buffer: Some(aabb_raw),
1058                        offset: aabb.primitive_offset,
1059                        count: aabb.size.primitive_count,
1060                        stride: aabb.stride,
1061                        flags: aabb.size.flags,
1062                    });
1063                }
1064
1065                {
1066                    let scratch_buffer_offset = *scratch_buffer_blas_size;
1067                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
1068                        blas.size_info.build_scratch_size,
1069                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
1070                    ));
1071
1072                    blas_storage.push(BlasStore {
1073                        blas: blas.clone(),
1074                        entries: hal::AccelerationStructureEntries::AABBs(aabb_entries),
1075                        scratch_buffer_offset,
1076                    });
1077                }
1078            }
1079        }
1080    }
1081    Ok(())
1082}
1083
1084fn map_blas<'a>(
1085    storage: &'a BlasStore<'_>,
1086    scratch_buffer: &'a dyn hal::DynBuffer,
1087    snatch_guard: &'a SnatchGuard,
1088    blases_compactable: &mut Vec<(
1089        &'a dyn hal::DynBuffer,
1090        &'a dyn hal::DynAccelerationStructure,
1091    )>,
1092) -> Result<
1093    hal::BuildAccelerationStructureDescriptor<
1094        'a,
1095        dyn hal::DynBuffer,
1096        dyn hal::DynAccelerationStructure,
1097    >,
1098    BuildAccelerationStructureError,
1099> {
1100    let BlasStore {
1101        blas,
1102        entries,
1103        scratch_buffer_offset,
1104    } = storage;
1105    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1106        log::debug!("only rebuild implemented")
1107    }
1108    let raw = blas.try_raw(snatch_guard)?;
1109
1110    let state_lock = blas.compacted_state.lock();
1111    if let BlasCompactState::Compacted = *state_lock {
1112        return Err(BuildAccelerationStructureError::CompactedBlas(
1113            blas.error_ident(),
1114        ));
1115    }
1116
1117    if blas
1118        .flags
1119        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1120    {
1121        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1122    }
1123    Ok(hal::BuildAccelerationStructureDescriptor {
1124        entries,
1125        mode: hal::AccelerationStructureBuildMode::Build,
1126        flags: blas.flags,
1127        source_acceleration_structure: None,
1128        destination_acceleration_structure: raw,
1129        scratch_buffer,
1130        scratch_buffer_offset: *scratch_buffer_offset,
1131    })
1132}
1133
1134fn build_blas<'a>(
1135    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
1136    blas_present: bool,
1137    tlas_present: bool,
1138    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1139    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1140        'a,
1141        dyn hal::DynBuffer,
1142        dyn hal::DynAccelerationStructure,
1143    >],
1144    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1145    blas_s_for_compaction: Vec<(
1146        &'a dyn hal::DynBuffer,
1147        &'a dyn hal::DynAccelerationStructure,
1148    )>,
1149) {
1150    unsafe {
1151        cmd_buf_raw.transition_buffers(&input_barriers);
1152    }
1153
1154    if blas_present {
1155        unsafe {
1156            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1157                usage: hal::StateTransition {
1158                    from: hal::AccelerationStructureUses::BUILD_INPUT,
1159                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1160                },
1161            });
1162
1163            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1164        }
1165    }
1166
1167    if blas_present && tlas_present {
1168        unsafe {
1169            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1170        }
1171    }
1172
1173    let mut source_usage = hal::AccelerationStructureUses::empty();
1174    let mut destination_usage = hal::AccelerationStructureUses::empty();
1175    for &(buf, blas) in blas_s_for_compaction.iter() {
1176        unsafe {
1177            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
1178                buffer: buf,
1179                usage: hal::StateTransition {
1180                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1181                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1182                },
1183            }])
1184        }
1185        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
1186        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
1187    }
1188
1189    if blas_present {
1190        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1191        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1192    }
1193    if tlas_present {
1194        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1195        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1196    }
1197    unsafe {
1198        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1199            usage: hal::StateTransition {
1200                from: source_usage,
1201                to: destination_usage,
1202            },
1203        });
1204    }
1205}