// wgpu_core/command/ray_tracing.rs

1use alloc::{sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11    command::encoder::EncodingState,
12    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
13    resource::InvalidResourceError,
14};
15use crate::{command::EncoderStateError, device::resource::CommandIndices};
16use crate::{
17    command::{ArcCommand, ArcReferences, CommandBufferMutable},
18    device::queue::TempResource,
19    global::Global,
20    id::CommandEncoderId,
21    init_tracker::MemoryInitKind,
22    ray_tracing::{
23        ArcBlasAabbGeometry, ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry,
24        ArcTlasInstance, ArcTlasPackage, BlasBuildEntry, BlasGeometries,
25        BuildAccelerationStructureError, OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
26    },
27    resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
28    scratch::ScratchBuffer,
29    snatch::SnatchGuard,
30};
31use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
32
33use crate::id::{BlasId, TlasId};
34
/// A BLAS build that has been validated and translated to HAL form:
/// the BLAS resource, the HAL geometry entries describing its inputs,
/// and the offset of its region within the shared scratch buffer.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
40
/// TLAS build data in HAL form, without the staging-byte range that ties it
/// back to the instance upload buffer (see [`TlasStore`], which pairs this
/// with that range).
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
46
/// A validated TLAS build: the HAL-level build data plus the byte range of
/// this TLAS's instances within the CPU-side staging source, used later to
/// copy the instance data into the TLAS's instance buffer.
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}
51
52impl Global {
53    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
54        self.hub.blas_s.get(blas_id).get()
55    }
56
57    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
58        self.hub.tlas_s.get(tlas_id).get()
59    }
60
61    pub fn command_encoder_mark_acceleration_structures_built(
62        &self,
63        command_encoder_id: CommandEncoderId,
64        blas_ids: &[BlasId],
65        tlas_ids: &[TlasId],
66    ) -> Result<(), EncoderStateError> {
67        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
68
69        let hub = &self.hub;
70
71        let cmd_enc = hub.command_encoders.get(command_encoder_id);
72
73        let mut cmd_buf_data = cmd_enc.data.lock();
74        cmd_buf_data.with_buffer(
75            crate::command::EncodingApi::Raw,
76            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
77                let device = &cmd_enc.device;
78                device.check_is_valid()?;
79                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
80
81                let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());
82
83                for blas in blas_ids {
84                    let blas = hub.blas_s.get(*blas).get()?;
85                    build_command.blas_s_built.push(blas);
86                }
87
88                for tlas in tlas_ids {
89                    let tlas = hub.tlas_s.get(*tlas).get()?;
90                    build_command.tlas_s_built.push(TlasBuild {
91                        tlas,
92                        dependencies: Vec::new(),
93                    });
94                }
95
96                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
97                Ok(())
98            },
99        )
100    }
101
102    pub fn command_encoder_build_acceleration_structures<'a>(
103        &self,
104        command_encoder_id: CommandEncoderId,
105        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
106        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
107    ) -> Result<(), EncoderStateError> {
108        profiling::scope!("CommandEncoder::build_acceleration_structures");
109
110        let hub = &self.hub;
111
112        let cmd_enc = hub.command_encoders.get(command_encoder_id);
113        let mut cmd_buf_data = cmd_enc.data.lock();
114
115        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
116            let blas = blas_iter
117                .map(|blas_entry| {
118                    let geometries = match blas_entry.geometries {
119                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
120                            let tri_geo = triangle_geometries
121                                .map(|tg| {
122                                    Ok(ArcBlasTriangleGeometry {
123                                        size: tg.size.clone(),
124                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
125                                        index_buffer: tg
126                                            .index_buffer
127                                            .map(|id| self.resolve_buffer_id(id))
128                                            .transpose()?,
129                                        transform_buffer: tg
130                                            .transform_buffer
131                                            .map(|id| self.resolve_buffer_id(id))
132                                            .transpose()?,
133                                        first_vertex: tg.first_vertex,
134                                        vertex_stride: tg.vertex_stride,
135                                        first_index: tg.first_index,
136                                        transform_buffer_offset: tg.transform_buffer_offset,
137                                    })
138                                })
139                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
140                            ArcBlasGeometries::TriangleGeometries(tri_geo)
141                        }
142                        BlasGeometries::AabbGeometries(aabb_geometries) => {
143                            let aabb_geo = aabb_geometries
144                                .map(|ag| {
145                                    Ok(ArcBlasAabbGeometry {
146                                        size: ag.size.clone(),
147                                        stride: ag.stride,
148                                        aabb_buffer: self.resolve_buffer_id(ag.aabb_buffer)?,
149                                        primitive_offset: ag.primitive_offset,
150                                    })
151                                })
152                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
153                            ArcBlasGeometries::AabbGeometries(aabb_geo)
154                        }
155                    };
156                    Ok(ArcBlasBuildEntry {
157                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
158                        geometries,
159                    })
160                })
161                .collect::<Result<_, BuildAccelerationStructureError>>()?;
162
163            let tlas = tlas_iter
164                .map(|tlas_package| {
165                    let instances = tlas_package
166                        .instances
167                        .map(|instance| {
168                            instance
169                                .as_ref()
170                                .map(|instance| {
171                                    Ok(ArcTlasInstance {
172                                        blas: self.resolve_blas_id(instance.blas_id)?,
173                                        transform: *instance.transform,
174                                        custom_data: instance.custom_data,
175                                        mask: instance.mask,
176                                    })
177                                })
178                                .transpose()
179                        })
180                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
181                    Ok(ArcTlasPackage {
182                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
183                        instances,
184                        lowest_unmodified: tlas_package.lowest_unmodified,
185                    })
186                })
187                .collect::<Result<_, BuildAccelerationStructureError>>()?;
188
189            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
190        })
191    }
192}
193
/// Records the given BLAS and TLAS builds into the encoder.
///
/// Validates every build entry, sizes a single shared scratch buffer to the
/// larger of the total BLAS/TLAS scratch requirements, stages TLAS instance
/// data through a temporary buffer, and emits the HAL build commands together
/// with the barriers that order them (BLAS builds first, then TLAS builds).
/// Temporary buffers are handed to `state.temp_resources` so they live until
/// the submission completes.
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    // Validate BLAS entries and collect their HAL geometry descriptions,
    // input-buffer barriers, and the total BLAS scratch requirement.
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    // CPU-side bytes for all TLAS instances; uploaded later via a staging
    // buffer. Each TLAS records its byte range into this vector.
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        // Each TLAS gets its own aligned region of the shared scratch buffer.
        let scratch_buffer_offset = scratch_buffer_tlas_size;
        // Guard against overflow in the align_to below.
        assert!(
            tlas.size_info.build_scratch_size
                <= u64::MAX
                    - u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment)
        );
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size,
            u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
        );

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.iter().flatten() {
            // Custom index is limited to 24 bits by the instance format.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            // Serialize the instance into the backend's on-GPU layout.
            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // A vertex-return TLAS may only reference vertex-return BLASes.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // One scratch buffer serves both phases; BLAS and TLAS builds are
    // separated by a barrier, so the space can be reused.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        // if the size is zero there is nothing to build
        return Ok(());
    };

    assert!(scratch_size.get() != u64::MAX);

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Scratch-to-scratch barrier between the BLAS and TLAS build phases.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());

    // Translate the validated BLAS entries into HAL build descriptors.
    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    // Phase 1: record all BLAS builds (with input barriers), followed by the
    // scratch barrier when TLAS builds will reuse the scratch buffer.
    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    // Phase 2: upload instance data and record all TLAS builds.
    if tlas_present {
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        // Make the CPU writes to the staging buffer visible to the copies below.
        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // A TLAS with no instances has nothing to copy.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            // Deferred barrier: instance buffer back to build-input state,
            // issued as a batch after all copies below.
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Transition the instance buffer to COPY_DST, then copy this
                // TLAS's slice of the staging data into it.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                raw_encoder.copy_buffer_to_buffer(
                    // `staging_buffer` must exist: a nonempty range implies
                    // nonempty staging data above.
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            // Return all instance buffers to build-input state, build the
            // TLASes, then make the results visible to shaders.
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        // Keep the staging buffer alive until the submission completes.
        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
470
impl CommandBufferMutable {
    /// Validates the recorded acceleration-structure actions at submission
    /// time.
    ///
    /// For each `Build` action: assigns a fresh build index to every built
    /// BLAS/TLAS, resets BLAS compaction state to `Idle`, and checks that
    /// every BLAS a TLAS depends on has been built. For each `UseTlas`
    /// action: checks the TLAS and all its BLAS dependencies were built, that
    /// no BLAS was rebuilt after the TLAS (which would make the TLAS stale),
    /// and that each dependency's raw handle is still alive.
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    // Each build action gets a unique, monotonically
                    // increasing index used to order BLAS vs. TLAS builds.
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            // Reset the compacted state to idle. This means any prepares, before mapping their
                            // internal buffer, will terminate.
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        // Every BLAS referenced by the TLAS must have been
                        // built at some point before this submission.
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        // Record the dependency set on the TLAS so later
                        // `UseTlas` validation can re-check it.
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        // A BLAS rebuilt after the TLAS makes the TLAS stale.
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        // Ensure the BLAS hasn't been destroyed.
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }

    /// Collects the raw handles of every BLAS this command buffer's
    /// acceleration-structure actions depend on and reports them to the
    /// backend encoder in a single call.
    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
        // Acquire all `UseTlas` dependency read-locks up front so the
        // borrowed dependency lists live long enough for the loop below.
        let tlas_dependencies_locks: Vec<_> = self
            .as_actions
            .iter()
            .filter_map(|action| {
                if let AsAction::UseTlas(tlas) = action {
                    Some(tlas.dependencies.read())
                } else {
                    None
                }
            })
            .collect();
        // This iterator advances in lockstep with the `UseTlas` arms below:
        // the filter above visited actions in the same order.
        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
        let mut dependencies = Vec::new();
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    for tlas_build in build.tlas_s_built.iter() {
                        for dependency in &tlas_build.dependencies {
                            // Destroyed BLASes are silently skipped here;
                            // validation reports them separately.
                            if let Some(dependency) = dependency.raw(snatch_guard) {
                                dependencies.push(dependency);
                            }
                        }
                    }
                }
                AsAction::UseTlas(_tlas) => {
                    // Guard pre-acquired above; same position as `_tlas`.
                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap();
                    for dependency in tlas_dependencies.iter() {
                        if let Some(dependency) = dependency.raw(snatch_guard) {
                            dependencies.push(dependency);
                        }
                    }
                }
            }
        }
        if !dependencies.is_empty() {
            unsafe {
                self.encoder
                    .raw
                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
            }
        }
    }
}
591
/// Iterates over the BLAS build entries and their geometries, pushing the buffers into a storage vector (and performing some validation).
593fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
594    blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
595    build_command: &mut AsBuild,
596    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
597    scratch_buffer_blas_size: &mut u64,
598    blas_storage: &mut Vec<BlasStore<'buffers>>,
599    state: &mut EncodingState<'snatch_guard, '_>,
600) -> Result<(), BuildAccelerationStructureError> {
601    for entry in blas_iter {
602        let blas = &entry.blas;
603        state.tracker.blas_s.insert_single(blas.clone());
604
605        build_command.blas_s_built.push(blas.clone());
606
607        match &entry.geometries {
608            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
609                let mut triangle_entries =
610                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
611
612                for (i, mesh) in triangle_geometries.iter().enumerate() {
613                    let size_desc = match &blas.sizes {
614                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
615                        _ => {
616                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
617                                blas.error_ident(),
618                            ));
619                        }
620                    };
621                    if i >= size_desc.len() {
622                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
623                            blas.error_ident(),
624                        ));
625                    }
626                    let size_desc = &size_desc[i];
627
628                    if size_desc.flags != mesh.size.flags {
629                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
630                            blas.error_ident(),
631                            size_desc.flags,
632                            mesh.size.flags,
633                        ));
634                    }
635
636                    if size_desc.vertex_count < mesh.size.vertex_count {
637                        return Err(
638                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
639                                blas.error_ident(),
640                                size_desc.vertex_count,
641                                mesh.size.vertex_count,
642                            ),
643                        );
644                    }
645
646                    if size_desc.vertex_format != mesh.size.vertex_format {
647                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
648                            blas.error_ident(),
649                            size_desc.vertex_format,
650                            mesh.size.vertex_format,
651                        ));
652                    }
653
654                    if size_desc
655                        .vertex_format
656                        .min_acceleration_structure_vertex_stride()
657                        > mesh.vertex_stride
658                    {
659                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
660                            blas.error_ident(),
661                            size_desc
662                                .vertex_format
663                                .min_acceleration_structure_vertex_stride(),
664                            mesh.vertex_stride,
665                        ));
666                    }
667
668                    if mesh.vertex_stride
669                        % size_desc
670                            .vertex_format
671                            .acceleration_structure_stride_alignment()
672                        != 0
673                    {
674                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
675                            blas.error_ident(),
676                            size_desc
677                                .vertex_format
678                                .acceleration_structure_stride_alignment(),
679                            mesh.vertex_stride,
680                        ));
681                    }
682
683                    match (size_desc.index_count, mesh.size.index_count) {
684                        (Some(_), None) | (None, Some(_)) => {
685                            return Err(
686                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
687                                    blas.error_ident(),
688                                ),
689                            )
690                        }
691                        (Some(create), Some(build)) if create < build => {
692                            return Err(
693                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
694                                    blas.error_ident(),
695                                    create,
696                                    build,
697                                ),
698                            )
699                        }
700                        _ => {}
701                    }
702
703                    if size_desc.index_format != mesh.size.index_format {
704                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
705                            blas.error_ident(),
706                            size_desc.index_format,
707                            mesh.size.index_format,
708                        ));
709                    }
710
711                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
712                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
713                            blas.error_ident(),
714                        ));
715                    }
716                    let vertex_buffer = mesh.vertex_buffer.clone();
717                    let vertex_pending = state.tracker.buffers.set_single(
718                        &vertex_buffer,
719                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
720                    );
721                    let vertex_buffer = {
722                        let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
723                        let vertex_buffer = &mesh.vertex_buffer;
724                        vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
725
726                        if let Some(barrier) = vertex_pending.map(|pending| {
727                            pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
728                        }) {
729                            input_barriers.push(barrier);
730                        }
731                        if u64::from(mesh.size.vertex_count)
732                            .checked_add(u64::from(mesh.first_vertex))
733                            .and_then(|end_vertex| end_vertex.checked_mul(mesh.vertex_stride))
734                            .is_none_or(|end| vertex_buffer.size < end)
735                        {
736                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
737                                buffer_ident: vertex_buffer.error_ident(),
738                                offset: u64::from(mesh.first_vertex)
739                                    .saturating_mul(mesh.vertex_stride),
740                                region_size: u64::from(mesh.size.vertex_count)
741                                    .saturating_mul(mesh.vertex_stride),
742                                buffer_size: vertex_buffer.size,
743                            });
744                        }
745                        let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
746                        state.buffer_memory_init_actions.extend(
747                            vertex_buffer.initialization_status.read().create_action(
748                                vertex_buffer,
749                                vertex_buffer_offset
750                                    ..(vertex_buffer_offset
751                                        + mesh.size.vertex_count as u64 * mesh.vertex_stride),
752                                MemoryInitKind::NeedsInitializedMemory,
753                            ),
754                        );
755                        vertex_raw
756                    };
757                    let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
758                        if mesh.first_index.is_none()
759                            || mesh.size.index_count.is_none()
760                            || mesh.size.index_count.is_none()
761                        {
762                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
763                                index_buffer.error_ident(),
764                            ));
765                        }
766                        let index_pending = state.tracker.buffers.set_single(
767                            index_buffer,
768                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
769                        );
770                        let index_raw = index_buffer.try_raw(state.snatch_guard)?;
771                        index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
772
773                        if let Some(barrier) = index_pending.map(|pending| {
774                            pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
775                        }) {
776                            input_barriers.push(barrier);
777                        }
778                        let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
779                        // `hal::AccelerationStructureTriangleIndices` accepts only `u32` offset
780                        let vertex_offset = mesh.first_index.unwrap().saturating_mul(index_stride);
781                        let indexes_size =
782                            mesh.size.index_count.unwrap().saturating_mul(index_stride);
783
784                        if mesh.size.index_count.unwrap() % 3 != 0 {
785                            return Err(BuildAccelerationStructureError::InvalidIndexCount(
786                                index_buffer.error_ident(),
787                                mesh.size.index_count.unwrap(),
788                            ));
789                        }
790                        if indexes_size == u32::MAX
791                            || vertex_offset == u32::MAX
792                            || index_buffer.size < u64::from(vertex_offset)
793                            || index_buffer.size - u64::from(vertex_offset)
794                                < u64::from(indexes_size)
795                        {
796                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
797                                buffer_ident: index_buffer.error_ident(),
798                                offset: u64::from(vertex_offset),
799                                region_size: u64::from(indexes_size),
800                                buffer_size: index_buffer.size,
801                            });
802                        }
803
804                        state.buffer_memory_init_actions.extend(
805                            index_buffer.initialization_status.read().create_action(
806                                index_buffer,
807                                u64::from(vertex_offset)
808                                    ..(u64::from(vertex_offset) + u64::from(indexes_size)),
809                                MemoryInitKind::NeedsInitializedMemory,
810                            ),
811                        );
812                        Some((index_raw, vertex_offset))
813                    } else {
814                        None
815                    };
816                    let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
817                    {
818                        if !blas
819                            .flags
820                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
821                        {
822                            return Err(BuildAccelerationStructureError::UseTransformMissing(
823                                blas.error_ident(),
824                            ));
825                        }
826                        if mesh.transform_buffer_offset.is_none() {
827                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
828                                transform_buffer.error_ident(),
829                            ));
830                        }
831                        let transform_pending = state.tracker.buffers.set_single(
832                            transform_buffer,
833                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
834                        );
835                        if mesh.transform_buffer_offset.is_none() {
836                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
837                                transform_buffer.error_ident(),
838                            ));
839                        }
840                        let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
841                        transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
842
843                        if let Some(barrier) = transform_pending.map(|pending| {
844                            pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
845                        }) {
846                            input_barriers.push(barrier);
847                        }
848
849                        // `hal::AccelerationStructureTriangleTransform` accepts only `u32` offset
850                        let offset = u32::try_from(mesh.transform_buffer_offset.unwrap())
851                            .unwrap_or(u32::MAX);
852
853                        if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT as u32 != 0 {
854                            return Err(
855                                BuildAccelerationStructureError::UnalignedTransformBufferOffset(
856                                    transform_buffer.error_ident(),
857                                ),
858                            );
859                        }
860                        if offset == u32::MAX
861                            || transform_buffer.size < 48
862                            || transform_buffer.size - 48 < u64::from(offset)
863                        {
864                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
865                                buffer_ident: transform_buffer.error_ident(),
866                                region_size: 48,
867                                offset: u64::from(offset),
868                                buffer_size: transform_buffer.size,
869                            });
870                        }
871                        state.buffer_memory_init_actions.extend(
872                            transform_buffer.initialization_status.read().create_action(
873                                transform_buffer,
874                                u64::from(offset)..(u64::from(offset) + 48),
875                                MemoryInitKind::NeedsInitializedMemory,
876                            ),
877                        );
878                        Some((transform_raw, offset))
879                    } else {
880                        if blas
881                            .flags
882                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
883                        {
884                            return Err(BuildAccelerationStructureError::TransformMissing(
885                                blas.error_ident(),
886                            ));
887                        }
888                        None
889                    };
890
891                    let triangles = hal::AccelerationStructureTriangles {
892                        vertex_buffer: Some(vertex_buffer),
893                        vertex_format: mesh.size.vertex_format,
894                        first_vertex: mesh.first_vertex,
895                        vertex_count: mesh.size.vertex_count,
896                        vertex_stride: mesh.vertex_stride,
897                        indices: index_buffer.map(|(index_raw, vertex_offset)| {
898                            hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
899                                format: mesh.size.index_format.unwrap(),
900                                buffer: Some(index_raw),
901                                offset: vertex_offset,
902                                count: mesh.size.index_count.unwrap(),
903                            }
904                        }),
905                        transform: transform_buffer.map(|(transform_raw, offset)| {
906                            hal::AccelerationStructureTriangleTransform {
907                                buffer: transform_raw,
908                                offset,
909                            }
910                        }),
911                        flags: mesh.size.flags,
912                    };
913                    triangle_entries.push(triangles);
914                }
915
916                {
917                    let scratch_buffer_offset = *scratch_buffer_blas_size;
918                    assert!(
919                        blas.size_info.build_scratch_size
920                            <= u64::MAX
921                                - u64::from(
922                                    state.device.alignments.ray_tracing_scratch_buffer_alignment
923                                )
924                    );
925                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
926                        blas.size_info.build_scratch_size,
927                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
928                    ));
929
930                    blas_storage.push(BlasStore {
931                        blas: blas.clone(),
932                        entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
933                        scratch_buffer_offset,
934                    });
935                }
936            }
937            ArcBlasGeometries::AabbGeometries(aabb_geometries) => {
938                let mut aabb_entries =
939                    Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::new();
940
941                for (i, aabb) in aabb_geometries.iter().enumerate() {
942                    let size_desc = match &blas.sizes {
943                        wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => descriptors,
944                        _ => {
945                            return Err(BuildAccelerationStructureError::BlasGeometryKindMismatch(
946                                blas.error_ident(),
947                            ));
948                        }
949                    };
950                    if i >= size_desc.len() {
951                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
952                            blas.error_ident(),
953                        ));
954                    }
955                    let size_desc = &size_desc[i];
956
957                    if size_desc.flags != aabb.size.flags {
958                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
959                            blas.error_ident(),
960                            size_desc.flags,
961                            aabb.size.flags,
962                        ));
963                    }
964
965                    if size_desc.primitive_count < aabb.size.primitive_count {
966                        return Err(
967                            BuildAccelerationStructureError::IncompatibleBlasAabbPrimitiveCount(
968                                blas.error_ident(),
969                                size_desc.primitive_count,
970                                aabb.size.primitive_count,
971                            ),
972                        );
973                    }
974
975                    if aabb.primitive_offset % 8 != 0 {
976                        return Err(
977                            BuildAccelerationStructureError::UnalignedAabbPrimitiveOffset(
978                                blas.error_ident(),
979                            ),
980                        );
981                    }
982
983                    if aabb.stride < wgt::AABB_GEOMETRY_MIN_STRIDE || aabb.stride % 8 != 0 {
984                        return Err(BuildAccelerationStructureError::InvalidAabbStride(
985                            blas.error_ident(),
986                            aabb.stride,
987                        ));
988                    }
989
990                    let aabb_buffer = aabb.aabb_buffer.clone();
991                    let aabb_pending = state.tracker.buffers.set_single(
992                        &aabb_buffer,
993                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
994                    );
995                    let aabb_raw = {
996                        let aabb_raw = aabb.aabb_buffer.as_ref().try_raw(state.snatch_guard)?;
997                        let aabb_buffer = &aabb.aabb_buffer;
998                        aabb_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
999
1000                        if let Some(barrier) = aabb_pending.map(|pending| {
1001                            pending.into_hal(aabb_buffer.as_ref(), state.snatch_guard)
1002                        }) {
1003                            input_barriers.push(barrier);
1004                        }
1005
1006                        // `hal::AccelerationStructureAABBs` accepts only `u32` offset
1007                        let aabb_size = aabb
1008                            .size
1009                            .primitive_count
1010                            .saturating_mul(u32::try_from(aabb.stride).unwrap());
1011                        if aabb_size == u32::MAX
1012                            || aabb_buffer.size < u64::from(aabb.primitive_offset)
1013                            || aabb_buffer.size - u64::from(aabb.primitive_offset)
1014                                < u64::from(aabb_size)
1015                        {
1016                            return Err(BuildAccelerationStructureError::InsufficientBufferSize {
1017                                buffer_ident: aabb_buffer.error_ident(),
1018                                region_size: u64::from(aabb_size),
1019                                offset: u64::from(aabb.primitive_offset),
1020                                buffer_size: aabb_buffer.size,
1021                            });
1022                        }
1023
1024                        state.buffer_memory_init_actions.extend(
1025                            aabb_buffer.initialization_status.read().create_action(
1026                                aabb_buffer,
1027                                u64::from(aabb.primitive_offset)
1028                                    ..u64::from(aabb.primitive_offset) + u64::from(aabb_size),
1029                                MemoryInitKind::NeedsInitializedMemory,
1030                            ),
1031                        );
1032                        aabb_raw
1033                    };
1034
1035                    aabb_entries.push(hal::AccelerationStructureAABBs {
1036                        buffer: Some(aabb_raw),
1037                        offset: aabb.primitive_offset,
1038                        count: aabb.size.primitive_count,
1039                        stride: aabb.stride,
1040                        flags: aabb.size.flags,
1041                    });
1042                }
1043
1044                {
1045                    let scratch_buffer_offset = *scratch_buffer_blas_size;
1046                    *scratch_buffer_blas_size = scratch_buffer_blas_size.saturating_add(align_to(
1047                        blas.size_info.build_scratch_size,
1048                        u64::from(state.device.alignments.ray_tracing_scratch_buffer_alignment),
1049                    ));
1050
1051                    blas_storage.push(BlasStore {
1052                        blas: blas.clone(),
1053                        entries: hal::AccelerationStructureEntries::AABBs(aabb_entries),
1054                        scratch_buffer_offset,
1055                    });
1056                }
1057            }
1058        }
1059    }
1060    Ok(())
1061}
1062
1063fn map_blas<'a>(
1064    storage: &'a BlasStore<'_>,
1065    scratch_buffer: &'a dyn hal::DynBuffer,
1066    snatch_guard: &'a SnatchGuard,
1067    blases_compactable: &mut Vec<(
1068        &'a dyn hal::DynBuffer,
1069        &'a dyn hal::DynAccelerationStructure,
1070    )>,
1071) -> Result<
1072    hal::BuildAccelerationStructureDescriptor<
1073        'a,
1074        dyn hal::DynBuffer,
1075        dyn hal::DynAccelerationStructure,
1076    >,
1077    BuildAccelerationStructureError,
1078> {
1079    let BlasStore {
1080        blas,
1081        entries,
1082        scratch_buffer_offset,
1083    } = storage;
1084    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1085        log::debug!("only rebuild implemented")
1086    }
1087    let raw = blas.try_raw(snatch_guard)?;
1088
1089    let state_lock = blas.compacted_state.lock();
1090    if let BlasCompactState::Compacted = *state_lock {
1091        return Err(BuildAccelerationStructureError::CompactedBlas(
1092            blas.error_ident(),
1093        ));
1094    }
1095
1096    if blas
1097        .flags
1098        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
1099    {
1100        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
1101    }
1102    Ok(hal::BuildAccelerationStructureDescriptor {
1103        entries,
1104        mode: hal::AccelerationStructureBuildMode::Build,
1105        flags: blas.flags,
1106        source_acceleration_structure: None,
1107        destination_acceleration_structure: raw,
1108        scratch_buffer,
1109        scratch_buffer_offset: *scratch_buffer_offset,
1110    })
1111}
1112
1113fn build_blas<'a>(
1114    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
1115    blas_present: bool,
1116    tlas_present: bool,
1117    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1118    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1119        'a,
1120        dyn hal::DynBuffer,
1121        dyn hal::DynAccelerationStructure,
1122    >],
1123    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1124    blas_s_for_compaction: Vec<(
1125        &'a dyn hal::DynBuffer,
1126        &'a dyn hal::DynAccelerationStructure,
1127    )>,
1128) {
1129    unsafe {
1130        cmd_buf_raw.transition_buffers(&input_barriers);
1131    }
1132
1133    if blas_present {
1134        unsafe {
1135            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1136                usage: hal::StateTransition {
1137                    from: hal::AccelerationStructureUses::BUILD_INPUT,
1138                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1139                },
1140            });
1141
1142            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1143        }
1144    }
1145
1146    if blas_present && tlas_present {
1147        unsafe {
1148            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1149        }
1150    }
1151
1152    let mut source_usage = hal::AccelerationStructureUses::empty();
1153    let mut destination_usage = hal::AccelerationStructureUses::empty();
1154    for &(buf, blas) in blas_s_for_compaction.iter() {
1155        unsafe {
1156            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
1157                buffer: buf,
1158                usage: hal::StateTransition {
1159                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1160                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1161                },
1162            }])
1163        }
1164        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
1165        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
1166    }
1167
1168    if blas_present {
1169        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1170        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1171    }
1172    if tlas_present {
1173        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1174        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1175    }
1176    unsafe {
1177        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1178            usage: hal::StateTransition {
1179                from: source_usage,
1180                to: destination_usage,
1181            },
1182        });
1183    }
1184}