// wgpu_core/command/ray_tracing.rs

1use alloc::{sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11    command::encoder::EncodingState,
12    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
13    resource::InvalidResourceError,
14};
15use crate::{command::EncoderStateError, device::resource::CommandIndices};
16use crate::{
17    command::{ArcCommand, ArcReferences, CommandBufferMutable},
18    device::queue::TempResource,
19    global::Global,
20    id::CommandEncoderId,
21    init_tracker::MemoryInitKind,
22    ray_tracing::{
23        ArcBlasAabbGeometry, ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry,
24        ArcTlasInstance, ArcTlasPackage, BlasBuildEntry, BlasGeometries,
25        BuildAccelerationStructureError, OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
26    },
27    resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
28    scratch::ScratchBuffer,
29    snatch::SnatchGuard,
30};
31use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
32
33use crate::id::{BlasId, TlasId};
34
35struct BlasStore<'a> {
36    blas: Arc<Blas>,
37    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
38    scratch_buffer_offset: u64,
39}
40
41struct UnsafeTlasStore<'a> {
42    tlas: Arc<Tlas>,
43    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
44    scratch_buffer_offset: u64,
45}
46
47struct TlasStore<'a> {
48    internal: UnsafeTlasStore<'a>,
49    range: Range<usize>,
50}
51
52impl Global {
53    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
54        self.hub.blas_s.get(blas_id).get()
55    }
56
57    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
58        self.hub.tlas_s.get(tlas_id).get()
59    }
60
61    pub fn command_encoder_mark_acceleration_structures_built(
62        &self,
63        command_encoder_id: CommandEncoderId,
64        blas_ids: &[BlasId],
65        tlas_ids: &[TlasId],
66    ) -> Result<(), EncoderStateError> {
67        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
68
69        let hub = &self.hub;
70
71        let cmd_enc = hub.command_encoders.get(command_encoder_id);
72
73        let mut cmd_buf_data = cmd_enc.data.lock();
74        cmd_buf_data.with_buffer(
75            crate::command::EncodingApi::Raw,
76            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
77                let device = &cmd_enc.device;
78                device.check_is_valid()?;
79                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
80
81                let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());
82
83                for blas in blas_ids {
84                    let blas = hub.blas_s.get(*blas).get()?;
85                    build_command.blas_s_built.push(blas);
86                }
87
88                for tlas in tlas_ids {
89                    let tlas = hub.tlas_s.get(*tlas).get()?;
90                    build_command.tlas_s_built.push(TlasBuild {
91                        tlas,
92                        dependencies: Vec::new(),
93                    });
94                }
95
96                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
97                Ok(())
98            },
99        )
100    }
101
102    pub fn command_encoder_build_acceleration_structures<'a>(
103        &self,
104        command_encoder_id: CommandEncoderId,
105        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
106        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
107    ) -> Result<(), EncoderStateError> {
108        profiling::scope!("CommandEncoder::build_acceleration_structures");
109
110        let hub = &self.hub;
111
112        let cmd_enc = hub.command_encoders.get(command_encoder_id);
113        let mut cmd_buf_data = cmd_enc.data.lock();
114
115        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
116            let blas = blas_iter
117                .map(|blas_entry| {
118                    let geometries = match blas_entry.geometries {
119                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
120                            let tri_geo = triangle_geometries
121                                .map(|tg| {
122                                    Ok(ArcBlasTriangleGeometry {
123                                        size: tg.size.clone(),
124                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
125                                        index_buffer: tg
126                                            .index_buffer
127                                            .map(|id| self.resolve_buffer_id(id))
128                                            .transpose()?,
129                                        transform_buffer: tg
130                                            .transform_buffer
131                                            .map(|id| self.resolve_buffer_id(id))
132                                            .transpose()?,
133                                        first_vertex: tg.first_vertex,
134                                        vertex_stride: tg.vertex_stride,
135                                        first_index: tg.first_index,
136                                        transform_buffer_offset: tg.transform_buffer_offset,
137                                    })
138                                })
139                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
140                            ArcBlasGeometries::TriangleGeometries(tri_geo)
141                        }
142                        BlasGeometries::AabbGeometries(aabb_geometries) => {
143                            let aabb_geo = aabb_geometries
144                                .map(|ag| {
145                                    Ok(ArcBlasAabbGeometry {
146                                        size: ag.size.clone(),
147                                        stride: ag.stride,
148                                        aabb_buffer: self.resolve_buffer_id(ag.aabb_buffer)?,
149                                        primitive_offset: ag.primitive_offset,
150                                    })
151                                })
152                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
153                            ArcBlasGeometries::AabbGeometries(aabb_geo)
154                        }
155                    };
156                    Ok(ArcBlasBuildEntry {
157                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
158                        geometries,
159                    })
160                })
161                .collect::<Result<_, BuildAccelerationStructureError>>()?;
162
163            let tlas = tlas_iter
164                .map(|tlas_package| {
165                    let instances = tlas_package
166                        .instances
167                        .map(|instance| {
168                            instance
169                                .as_ref()
170                                .map(|instance| {
171                                    Ok(ArcTlasInstance {
172                                        blas: self.resolve_blas_id(instance.blas_id)?,
173                                        transform: *instance.transform,
174                                        custom_data: instance.custom_data,
175                                        mask: instance.mask,
176                                    })
177                                })
178                                .transpose()
179                        })
180                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
181                    Ok(ArcTlasPackage {
182                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
183                        instances,
184                        lowest_unmodified: tlas_package.lowest_unmodified,
185                    })
186                })
187                .collect::<Result<_, BuildAccelerationStructureError>>()?;
188
189            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
190        })
191    }
192}
193
194pub(crate) fn build_acceleration_structures(
195    state: &mut EncodingState,
196    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
197    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
198) -> Result<(), BuildAccelerationStructureError> {
199    state
200        .device
201        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
202
203    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
204    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
205    let mut scratch_buffer_blas_size = 0;
206    let mut blas_storage = Vec::with_capacity(blas.len());
207    iter_blas(
208        blas.iter(),
209        &mut build_command,
210        &mut input_barriers,
211        &mut scratch_buffer_blas_size,
212        &mut blas_storage,
213        state,
214    )?;
215
216    let mut scratch_buffer_tlas_size: u64 = 0;
217    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
218    let mut instance_buffer_staging_source = Vec::<u8>::new();
219
220    for package in tlas.iter() {
221        let tlas = &package.tlas;
222        state.tracker.tlas_s.insert_single(tlas.clone());
223
224        let scratch_buffer_offset = scratch_buffer_tlas_size;
225        scratch_buffer_tlas_size = scratch_buffer_tlas_size.saturating_add(align_to(
226            tlas.size_info.build_scratch_size,
227            u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
228        ));
229
230        let first_byte_index = instance_buffer_staging_source.len();
231
232        let mut dependencies = Vec::new();
233
234        let mut instance_count = 0;
235        for instance in package.instances.iter().flatten() {
236            if instance.custom_data >= (1u32 << 24u32) {
237                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
238                    tlas.error_ident(),
239                ));
240            }
241            let blas = &instance.blas;
242            state.tracker.blas_s.insert_single(blas.clone());
243
244            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
245                hal::TlasInstance {
246                    transform: instance.transform,
247                    custom_data: instance.custom_data,
248                    mask: instance.mask,
249                    blas_address: blas.handle,
250                    pipeline_intersection_data_offset: 0,
251                },
252            ));
253
254            if tlas
255                .flags
256                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
257                && !blas
258                    .flags
259                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
260            {
261                return Err(
262                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
263                        tlas.error_ident(),
264                        blas.error_ident(),
265                    ),
266                );
267            }
268
269            instance_count += 1;
270
271            dependencies.push(blas.clone());
272        }
273
274        build_command.tlas_s_built.push(TlasBuild {
275            tlas: tlas.clone(),
276            dependencies,
277        });
278
279        if instance_count > tlas.max_instance_count {
280            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
281                tlas.error_ident(),
282                instance_count,
283                tlas.max_instance_count,
284            ));
285        }
286
287        tlas_storage.push(TlasStore {
288            internal: UnsafeTlasStore {
289                tlas: tlas.clone(),
290                entries: hal::AccelerationStructureEntries::Instances(
291                    hal::AccelerationStructureInstances {
292                        buffer: Some(tlas.instance_buffer.as_ref()),
293                        offset: 0,
294                        count: instance_count,
295                    },
296                ),
297                scratch_buffer_offset,
298            },
299            range: first_byte_index..instance_buffer_staging_source.len(),
300        });
301    }
302
303    let Some(scratch_size) =
304        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
305    else {
306        // if the size is zero there is nothing to build
307        return Ok(());
308    };
309
310    if scratch_size.get() == u64::MAX {
311        return Err(crate::device::DeviceError::OutOfMemory.into());
312    }
313
314    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;
315
316    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
317        buffer: scratch_buffer.raw(),
318        usage: hal::StateTransition {
319            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
320            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
321        },
322    };
323
324    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());
325
326    for &TlasStore {
327        internal:
328            UnsafeTlasStore {
329                ref tlas,
330                ref entries,
331                ref scratch_buffer_offset,
332            },
333        ..
334    } in &tlas_storage
335    {
336        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
337            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
338        }
339        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
340            entries,
341            mode: hal::AccelerationStructureBuildMode::Build,
342            flags: tlas.flags,
343            source_acceleration_structure: None,
344            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
345            scratch_buffer: scratch_buffer.raw(),
346            scratch_buffer_offset: *scratch_buffer_offset,
347        })
348    }
349
350    let blas_present = !blas_storage.is_empty();
351    let tlas_present = !tlas_storage.is_empty();
352
353    let raw_encoder = &mut state.raw_encoder;
354
355    let mut blas_s_compactable = Vec::new();
356    let mut descriptors = Vec::with_capacity(blas.len());
357
358    for storage in &blas_storage {
359        descriptors.push(map_blas(
360            storage,
361            scratch_buffer.raw(),
362            state.snatch_guard,
363            &mut blas_s_compactable,
364        )?);
365    }
366
367    build_blas(
368        *raw_encoder,
369        blas_present,
370        tlas_present,
371        input_barriers,
372        &descriptors,
373        scratch_buffer_barrier,
374        blas_s_compactable,
375    );
376
377    if tlas_present {
378        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
379            let mut staging_buffer = StagingBuffer::new(
380                state.device,
381                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
382            )?;
383            staging_buffer.write(&instance_buffer_staging_source);
384            let flushed = staging_buffer.flush();
385            Some(flushed)
386        } else {
387            None
388        };
389
390        unsafe {
391            if let Some(ref staging_buffer) = staging_buffer {
392                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
393                    buffer: staging_buffer.raw(),
394                    usage: hal::StateTransition {
395                        from: BufferUses::MAP_WRITE,
396                        to: BufferUses::COPY_SRC,
397                    },
398                }]);
399            }
400        }
401
402        let mut instance_buffer_barriers = Vec::new();
403        for &TlasStore {
404            internal: UnsafeTlasStore { ref tlas, .. },
405            ref range,
406        } in &tlas_storage
407        {
408            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
409                None => continue,
410                Some(size) => size,
411            };
412            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
413                buffer: tlas.instance_buffer.as_ref(),
414                usage: hal::StateTransition {
415                    from: BufferUses::COPY_DST,
416                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
417                },
418            });
419            unsafe {
420                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
421                    buffer: tlas.instance_buffer.as_ref(),
422                    usage: hal::StateTransition {
423                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
424                        to: BufferUses::COPY_DST,
425                    },
426                }]);
427                let temp = hal::BufferCopy {
428                    src_offset: range.start as u64,
429                    dst_offset: 0,
430                    size,
431                };
432                raw_encoder.copy_buffer_to_buffer(
433                    staging_buffer.as_ref().unwrap().raw(),
434                    tlas.instance_buffer.as_ref(),
435                    &[temp],
436                );
437            }
438        }
439
440        unsafe {
441            raw_encoder.transition_buffers(&instance_buffer_barriers);
442
443            raw_encoder.build_acceleration_structures(&tlas_descriptors);
444
445            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
446                usage: hal::StateTransition {
447                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
448                    to: hal::AccelerationStructureUses::SHADER_INPUT,
449                },
450            });
451        }
452
453        if let Some(staging_buffer) = staging_buffer {
454            state
455                .temp_resources
456                .push(TempResource::StagingBuffer(staging_buffer));
457        }
458    }
459
460    state
461        .temp_resources
462        .push(TempResource::ScratchBuffer(scratch_buffer));
463
464    state.as_actions.push(AsAction::Build(build_command));
465
466    Ok(())
467}
468
469impl CommandBufferMutable {
470    pub(crate) fn validate_acceleration_structure_actions(
471        &self,
472        snatch_guard: &SnatchGuard,
473        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
474    ) -> Result<(), ValidateAsActionsError> {
475        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
476        for action in &self.as_actions {
477            match action {
478                AsAction::Build(build) => {
479                    let build_command_index = NonZeroU64::new(
480                        command_index_guard.next_acceleration_structure_build_command_index,
481                    )
482                    .unwrap();
483
484                    command_index_guard.next_acceleration_structure_build_command_index += 1;
485                    for blas in build.blas_s_built.iter() {
486                        let mut state_lock = blas.compacted_state.lock();
487                        *state_lock = match *state_lock {
488                            BlasCompactState::Compacted => {
489                                unreachable!("Should be validated out in build.")
490                            }
491                            // Reset the compacted state to idle. This means any prepares, before mapping their
492                            // internal buffer, will terminate.
493                            _ => BlasCompactState::Idle,
494                        };
495                        *blas.built_index.write() = Some(build_command_index);
496                    }
497
498                    for tlas_build in build.tlas_s_built.iter() {
499                        for blas in &tlas_build.dependencies {
500                            if blas.built_index.read().is_none() {
501                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
502                                    blas.error_ident(),
503                                    tlas_build.tlas.error_ident(),
504                                ));
505                            }
506                        }
507                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
508                        tlas_build
509                            .tlas
510                            .dependencies
511                            .write()
512                            .clone_from(&tlas_build.dependencies)
513                    }
514                }
515                AsAction::UseTlas(tlas) => {
516                    let tlas_build_index = tlas.built_index.read();
517                    let dependencies = tlas.dependencies.read();
518
519                    if (*tlas_build_index).is_none() {
520                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
521                    }
522                    for blas in dependencies.deref() {
523                        let blas_build_index = *blas.built_index.read();
524                        if blas_build_index.is_none() {
525                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
526                                tlas.error_ident(),
527                                blas.error_ident(),
528                            ));
529                        }
530                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
531                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
532                                blas.error_ident(),
533                                tlas.error_ident(),
534                            ));
535                        }
536                        blas.try_raw(snatch_guard)?;
537                    }
538                }
539            }
540        }
541        Ok(())
542    }
543
544    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
545        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
546        let tlas_dependencies_locks: Vec<_> = self
547            .as_actions
548            .iter()
549            .filter_map(|action| {
550                if let AsAction::UseTlas(tlas) = action {
551                    Some(tlas.dependencies.read())
552                } else {
553                    None
554                }
555            })
556            .collect();
557        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
558        let mut dependencies = Vec::new();
559        for action in &self.as_actions {
560            match action {
561                AsAction::Build(build) => {
562                    for tlas_build in build.tlas_s_built.iter() {
563                        for dependency in &tlas_build.dependencies {
564                            if let Some(dependency) = dependency.raw(snatch_guard) {
565                                dependencies.push(dependency);
566                            }
567                        }
568                    }
569                }
570                AsAction::UseTlas(_tlas) => {
571                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap(); // _tlas.dependencies.read();
572                    for dependency in tlas_dependencies.iter() {
573                        if let Some(dependency) = dependency.raw(snatch_guard) {
574                            dependencies.push(dependency);
575                        }
576                    }
577                }
578            }
579        }
580        if !dependencies.is_empty() {
581            unsafe {
582                self.encoder
583                    .raw
584                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
585            }
586        }
587    }
588}
589
/// Iterates over the BLAS build entries and their geometries, pushing the buffers into a storage vector (and performing some validation along the way).
591fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
592    blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
593    build_command: &mut AsBuild,
594    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
595    scratch_buffer_blas_size: &mut u64,
596    blas_storage: &mut Vec<BlasStore<'buffers>>,
597    state: &mut EncodingState<'snatch_guard, '_>,
598) -> Result<(), BuildAccelerationStructureError> {
599    for entry in blas_iter {
600        let blas = &entry.blas;
601        state.tracker.blas_s.insert_single(blas.clone());
602
603        build_command.blas_s_built.push(blas.clone());
604
605        match &entry.geometries {
606            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
607                let mut triangle_entries =
608                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
609
610                for (i, mesh) in triangle_geometries.iter().enumerate() {
611                    let size_desc = match &blas.sizes {
612                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
613                        _ => {
614                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
615                                blas.error_ident(),
616                            ));
617                        }
618                    };
619                    if i >= size_desc.len() {
620                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
621                            blas.error_ident(),
622                        ));
623                    }
624                    let size_desc = &size_desc[i];
625
626                    if size_desc.flags != mesh.size.flags {
627                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
628                            blas.error_ident(),
629                            size_desc.flags,
630                            mesh.size.flags,
631                        ));
632                    }
633
634                    if size_desc.vertex_count < mesh.size.vertex_count {
635                        return Err(
636                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
637                                blas.error_ident(),
638                                size_desc.vertex_count,
639                                mesh.size.vertex_count,
640                            ),
641                        );
642                    }
643
644                    if size_desc.vertex_format != mesh.size.vertex_format {
645                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
646                            blas.error_ident(),
647                            size_desc.vertex_format,
648                            mesh.size.vertex_format,
649                        ));
650                    }
651
652                    if size_desc
653                        .vertex_format
654                        .min_acceleration_structure_vertex_stride()
655                        > mesh.vertex_stride
656                    {
657                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
658                            blas.error_ident(),
659                            size_desc
660                                .vertex_format
661                                .min_acceleration_structure_vertex_stride(),
662                            mesh.vertex_stride,
663                        ));
664                    }
665
666                    if mesh.vertex_stride
667                        % size_desc
668                            .vertex_format
669                            .acceleration_structure_stride_alignment()
670                        != 0
671                    {
672                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
673                            blas.error_ident(),
674                            size_desc
675                                .vertex_format
676                                .acceleration_structure_stride_alignment(),
677                            mesh.vertex_stride,
678                        ));
679                    }
680
681                    match (size_desc.index_count, mesh.size.index_count) {
682                        (Some(_), None) | (None, Some(_)) => {
683                            return Err(
684                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
685                                    blas.error_ident(),
686                                ),
687                            )
688                        }
689                        (Some(create), Some(build)) if create < build => {
690                            return Err(
691                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
692                                    blas.error_ident(),
693                                    create,
694                                    build,
695                                ),
696                            )
697                        }
698                        _ => {}
699                    }
700
701                    if size_desc.index_format != mesh.size.index_format {
702                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
703                            blas.error_ident(),
704                            size_desc.index_format,
705                            mesh.size.index_format,
706                        ));
707                    }
708
709                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
710                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
711                            blas.error_ident(),
712                        ));
713                    }
714                    let vertex_buffer = mesh.vertex_buffer.clone();
715                    let vertex_pending = state.tracker.buffers.set_single(
716                        &vertex_buffer,
717                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
718                    );
719                    let vertex_buffer = {
720                        let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
721                        let vertex_buffer = &mesh.vertex_buffer;
722                        vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
723
724                        if let Some(barrier) = vertex_pending.map(|pending| {
725                            pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
726                        }) {
727                            input_barriers.push(barrier);
728                        }
729                        if u64::from(mesh.size.vertex_count)
730                            .checked_add(u64::from(mesh.first_vertex))
731                            .and_then(|end_vertex| end_vertex.checked_mul(mesh.vertex_stride))
732                            .is_none_or(|end| vertex_buffer.size < end)
733                        {
734                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
735                                buffer_ident: vertex_buffer.error_ident(),
736                                offset: u64::from(mesh.first_vertex)
737                                    .saturating_mul(mesh.vertex_stride),
738                                region_size: u64::from(mesh.size.vertex_count)
739                                    .saturating_mul(mesh.vertex_stride),
740                                buffer_size: vertex_buffer.size,
741                            });
742                        }
743                        let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
744                        state.buffer_memory_init_actions.extend(
745                            vertex_buffer.initialization_status.read().create_action(
746                                vertex_buffer,
747                                vertex_buffer_offset
748                                    ..(vertex_buffer_offset
749                                        + mesh.size.vertex_count as u64 * mesh.vertex_stride),
750                                MemoryInitKind::NeedsInitializedMemory,
751                            ),
752                        );
753                        vertex_raw
754                    };
755                    let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
756                        if mesh.first_index.is_none()
757                            || mesh.size.index_count.is_none()
758                            || mesh.size.index_count.is_none()
759                        {
760                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
761                                index_buffer.error_ident(),
762                            ));
763                        }
764                        let index_pending = state.tracker.buffers.set_single(
765                            index_buffer,
766                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
767                        );
768                        let index_raw = index_buffer.try_raw(state.snatch_guard)?;
769                        index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
770
771                        if let Some(barrier) = index_pending.map(|pending| {
772                            pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
773                        }) {
774                            input_barriers.push(barrier);
775                        }
776                        let index_stride = mesh.size.index_format.unwrap().byte_size();
777                        // `hal::AccelerationStructureTriangleIndices` accepts only `u32` offset
778                        let Some(vertex_offset) =
779                            mesh.first_index.unwrap().checked_mul(index_stride)
780                        else {
781                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
782                                buffer_ident: index_buffer.error_ident(),
783                                offset: u64::from(mesh.first_index.unwrap())
784                                    .saturating_mul(u64::from(index_stride)),
785                                count: u64::from(mesh.first_index.unwrap()),
786                                stride: u64::from(index_stride),
787                            });
788                        };
789                        let Some(indexes_size) =
790                            mesh.size.index_count.unwrap().checked_mul(index_stride)
791                        else {
792                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
793                                buffer_ident: index_buffer.error_ident(),
794                                offset: u64::from(mesh.size.index_count.unwrap())
795                                    .saturating_mul(u64::from(index_stride)),
796                                count: u64::from(mesh.size.index_count.unwrap()),
797                                stride: u64::from(index_stride),
798                            });
799                        };
800
801                        if mesh.size.index_count.unwrap() % 3 != 0 {
802                            return Err(BuildAccelerationStructureError::InvalidIndexCount(
803                                index_buffer.error_ident(),
804                                mesh.size.index_count.unwrap(),
805                            ));
806                        }
807                        if index_buffer.size < u64::from(vertex_offset)
808                            || index_buffer.size - u64::from(vertex_offset)
809                                < u64::from(indexes_size)
810                        {
811                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
812                                buffer_ident: index_buffer.error_ident(),
813                                offset: u64::from(vertex_offset),
814                                region_size: u64::from(indexes_size),
815                                buffer_size: index_buffer.size,
816                            });
817                        }
818
819                        state.buffer_memory_init_actions.extend(
820                            index_buffer.initialization_status.read().create_action(
821                                index_buffer,
822                                u64::from(vertex_offset)
823                                    ..(u64::from(vertex_offset) + u64::from(indexes_size)),
824                                MemoryInitKind::NeedsInitializedMemory,
825                            ),
826                        );
827                        Some((index_raw, vertex_offset))
828                    } else {
829                        None
830                    };
831                    let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
832                    {
833                        if !blas
834                            .flags
835                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
836                        {
837                            return Err(BuildAccelerationStructureError::UseTransformMissing(
838                                blas.error_ident(),
839                            ));
840                        }
841                        if mesh.transform_buffer_offset.is_none() {
842                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
843                                transform_buffer.error_ident(),
844                            ));
845                        }
846                        let transform_pending = state.tracker.buffers.set_single(
847                            transform_buffer,
848                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
849                        );
850                        if mesh.transform_buffer_offset.is_none() {
851                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
852                                transform_buffer.error_ident(),
853                            ));
854                        }
855                        let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
856                        transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
857
858                        if let Some(barrier) = transform_pending.map(|pending| {
859                            pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
860                        }) {
861                            input_barriers.push(barrier);
862                        }
863
864                        // `hal::AccelerationStructureTriangleTransform` accepts only `u32` offset
865                        let Ok(offset) = u32::try_from(mesh.transform_buffer_offset.unwrap())
866                        else {
867                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
868                                buffer_ident: transform_buffer.error_ident(),
869                                offset: mesh.transform_buffer_offset.unwrap(),
870                                count: mesh.transform_buffer_offset.unwrap(),
871                                stride: 1,
872                            });
873                        };
874
875                        if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT as u32 != 0 {
876                            return Err(
877                                BuildAccelerationStructureError::UnalignedTransformBufferOffset(
878                                    transform_buffer.error_ident(),
879                                ),
880                            );
881                        }
882                        if transform_buffer.size < 48
883                            || transform_buffer.size - 48 < u64::from(offset)
884                        {
885                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
886                                buffer_ident: transform_buffer.error_ident(),
887                                region_size: 48,
888                                offset: u64::from(offset),
889                                buffer_size: transform_buffer.size,
890                            });
891                        }
892                        state.buffer_memory_init_actions.extend(
893                            transform_buffer.initialization_status.read().create_action(
894                                transform_buffer,
895                                u64::from(offset)..(u64::from(offset) + 48),
896                                MemoryInitKind::NeedsInitializedMemory,
897                            ),
898                        );
899                        Some((transform_raw, offset))
900                    } else {
901                        if blas
902                            .flags
903                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
904                        {
905                            return Err(BuildAccelerationStructureError::TransformMissing(
906                                blas.error_ident(),
907                            ));
908                        }
909                        None
910                    };
911
912                    let triangles = hal::AccelerationStructureTriangles {
913                        vertex_buffer: Some(vertex_buffer),
914                        vertex_format: mesh.size.vertex_format,
915                        first_vertex: mesh.first_vertex,
916                        vertex_count: mesh.size.vertex_count,
917                        vertex_stride: mesh.vertex_stride,
918                        indices: index_buffer.map(|(index_raw, vertex_offset)| {
919                            hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
920                                format: mesh.size.index_format.unwrap(),
921                                buffer: Some(index_raw),
922                                offset: vertex_offset,
923                                count: mesh.size.index_count.unwrap(),
924                            }
925                        }),
926                        transform: transform_buffer.map(|(transform_raw, offset)| {
927                            hal::AccelerationStructureTriangleTransform {
928                                buffer: transform_raw,
929                                offset,
930                            }
931                        }),
932                        flags: mesh.size.flags,
933                    };
934                    triangle_entries.push(triangles);
935                }
936
937                {
938                    let scratch_buffer_offset = *scratch_buffer_blas_size;
939                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
940                        blas.size_info.build_scratch_size,
941                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
942                    ));
943
944                    blas_storage.push(BlasStore {
945                        blas: blas.clone(),
946                        entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
947                        scratch_buffer_offset,
948                    });
949                }
950            }
951            ArcBlasGeometries::AabbGeometries(aabb_geometries) => {
952                let mut aabb_entries =
953                    Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::new();
954
955                for (i, aabb) in aabb_geometries.iter().enumerate() {
956                    let size_desc = match &blas.sizes {
957                        wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => descriptors,
958                        _ => {
959                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
960                                blas.error_ident(),
961                            ));
962                        }
963                    };
964                    if i >= size_desc.len() {
965                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
966                            blas.error_ident(),
967                        ));
968                    }
969                    let size_desc = &size_desc[i];
970
971                    if size_desc.flags != aabb.size.flags {
972                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
973                            blas.error_ident(),
974                            size_desc.flags,
975                            aabb.size.flags,
976                        ));
977                    }
978
979                    if size_desc.primitive_count < aabb.size.primitive_count {
980                        return Err(
981                            BuildAccelerationStructureError::IncompatibleBlasAabbPrimitiveCount(
982                                blas.error_ident(),
983                                size_desc.primitive_count,
984                                aabb.size.primitive_count,
985                            ),
986                        );
987                    }
988
989                    if aabb.primitive_offset % 8 != 0 {
990                        return Err(
991                            BuildAccelerationStructureError::UnalignedAabbPrimitiveOffset(
992                                blas.error_ident(),
993                            ),
994                        );
995                    }
996
997                    if aabb.stride < wgt::AABB_GEOMETRY_MIN_STRIDE || aabb.stride % 8 != 0 {
998                        return Err(BuildAccelerationStructureError::InvalidAabbStride(
999                            blas.error_ident(),
1000                            aabb.stride,
1001                        ));
1002                    }
1003
1004                    let aabb_buffer = aabb.aabb_buffer.clone();
1005                    let aabb_pending = state.tracker.buffers.set_single(
1006                        &aabb_buffer,
1007                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
1008                    );
1009                    let aabb_raw = {
1010                        let aabb_raw = aabb.aabb_buffer.as_ref().try_raw(state.snatch_guard)?;
1011                        let aabb_buffer = &aabb.aabb_buffer;
1012                        aabb_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1013
1014                        if let Some(barrier) = aabb_pending.map(|pending| {
1015                            pending.into_hal(aabb_buffer.as_ref(), state.snatch_guard)
1016                        }) {
1017                            input_barriers.push(barrier);
1018                        }
1019
1020                        // `hal::AccelerationStructureAABBs` accepts only `u32` offset
1021                        let Some(aabb_size) = u32::try_from(aabb.stride)
1022                            .ok()
1023                            .and_then(|stride| aabb.size.primitive_count.checked_mul(stride))
1024                        else {
1025                            return Err(BuildAccelerationStructureError::OffsetLimitedTo4GB {
1026                                buffer_ident: aabb_buffer.error_ident(),
1027                                offset: u64::from(aabb.size.primitive_count)
1028                                    .saturating_mul(aabb.stride),
1029                                count: u64::from(aabb.size.primitive_count),
1030                                stride: aabb.stride,
1031                            });
1032                        };
1033
1034                        if aabb_buffer.size < u64::from(aabb.primitive_offset)
1035                            || aabb_buffer.size - u64::from(aabb.primitive_offset)
1036                                < u64::from(aabb_size)
1037                        {
1038                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
1039                                buffer_ident: aabb_buffer.error_ident(),
1040                                region_size: u64::from(aabb_size),
1041                                offset: u64::from(aabb.primitive_offset),
1042                                buffer_size: aabb_buffer.size,
1043                            });
1044                        }
1045
1046                        state.buffer_memory_init_actions.extend(
1047                            aabb_buffer.initialization_status.read().create_action(
1048                                aabb_buffer,
1049                                u64::from(aabb.primitive_offset)
1050                                    ..u64::from(aabb.primitive_offset) + u64::from(aabb_size),
1051                                MemoryInitKind::NeedsInitializedMemory,
1052                            ),
1053                        );
1054                        aabb_raw
1055                    };
1056
1057                    aabb_entries.push(hal::AccelerationStructureAABBs {
1058                        buffer: Some(aabb_raw),
1059                        offset: aabb.primitive_offset,
1060                        count: aabb.size.primitive_count,
1061                        stride: aabb.stride,
1062                        flags: aabb.size.flags,
1063                    });
1064                }
1065
1066                {
1067                    let scratch_buffer_offset = *scratch_buffer_blas_size;
1068                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
1069                        blas.size_info.build_scratch_size,
1070                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
1071                    ));
1072
1073                    blas_storage.push(BlasStore {
1074                        blas: blas.clone(),
1075                        entries: hal::AccelerationStructureEntries::AABBs(aabb_entries),
1076                        scratch_buffer_offset,
1077                    });
1078                }
1079            }
1080        }
1081    }
1082    Ok(())
1083}
1084
1085fn map_blas<'a>(
1086    storage: &'a BlasStore<'_>,
1087    scratch_buffer: &'a dyn hal::DynBuffer,
1088    snatch_guard: &'a SnatchGuard,
1089    blases_compactable: &mut Vec<(
1090        &'a dyn hal::DynBuffer,
1091        &'a dyn hal::DynAccelerationStructure,
1092    )>,
1093) -> Result<
1094    hal::BuildAccelerationStructureDescriptor<
1095        'a,
1096        dyn hal::DynBuffer,
1097        dyn hal::DynAccelerationStructure,
1098    >,
1099    BuildAccelerationStructureError,
1100> {
1101    let BlasStore {
1102        blas,
1103        entries,
1104        scratch_buffer_offset,
1105    } = storage;
1106    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1107        log::debug!("only rebuild implemented")
1108    }
1109    let raw = blas.try_raw(snatch_guard)?;
1110
1111    let state_lock = blas.compacted_state.lock();
1112    if let BlasCompactState::Compacted = *state_lock {
1113        return Err(BuildAccelerationStructureError::CompactedBlas(
1114            blas.error_ident(),
1115        ));
1116    }
1117
1118    if blas
1119        .flags
1120        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1121    {
1122        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1123    }
1124    Ok(hal::BuildAccelerationStructureDescriptor {
1125        entries,
1126        mode: hal::AccelerationStructureBuildMode::Build,
1127        flags: blas.flags,
1128        source_acceleration_structure: None,
1129        destination_acceleration_structure: raw,
1130        scratch_buffer,
1131        scratch_buffer_offset: *scratch_buffer_offset,
1132    })
1133}
1134
1135fn build_blas<'a>(
1136    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
1137    blas_present: bool,
1138    tlas_present: bool,
1139    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1140    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1141        'a,
1142        dyn hal::DynBuffer,
1143        dyn hal::DynAccelerationStructure,
1144    >],
1145    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1146    blas_s_for_compaction: Vec<(
1147        &'a dyn hal::DynBuffer,
1148        &'a dyn hal::DynAccelerationStructure,
1149    )>,
1150) {
1151    unsafe {
1152        cmd_buf_raw.transition_buffers(&input_barriers);
1153    }
1154
1155    if blas_present {
1156        unsafe {
1157            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1158                usage: hal::StateTransition {
1159                    from: hal::AccelerationStructureUses::BUILD_INPUT,
1160                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1161                },
1162            });
1163
1164            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1165        }
1166    }
1167
1168    if blas_present && tlas_present {
1169        unsafe {
1170            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1171        }
1172    }
1173
1174    let mut source_usage = hal::AccelerationStructureUses::empty();
1175    let mut destination_usage = hal::AccelerationStructureUses::empty();
1176    for &(buf, blas) in blas_s_for_compaction.iter() {
1177        unsafe {
1178            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
1179                buffer: buf,
1180                usage: hal::StateTransition {
1181                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1182                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1183                },
1184            }])
1185        }
1186        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
1187        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
1188    }
1189
1190    if blas_present {
1191        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1192        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1193    }
1194    if tlas_present {
1195        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1196        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1197    }
1198    unsafe {
1199        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1200            usage: hal::StateTransition {
1201                from: source_usage,
1202                to: destination_usage,
1203            },
1204        });
1205    }
1206}