// wgpu_core/command/ray_tracing.rs

1use alloc::{sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11    command::encoder::EncodingState,
12    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
13    resource::InvalidResourceError,
14};
15use crate::{command::EncoderStateError, device::resource::CommandIndices};
16use crate::{
17    command::{ArcCommand, ArcReferences, CommandBufferMutable},
18    device::queue::TempResource,
19    global::Global,
20    id::CommandEncoderId,
21    init_tracker::MemoryInitKind,
22    ray_tracing::{
23        ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
24        ArcTlasPackage, BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError,
25        OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
26    },
27    resource::{Blas, BlasCompactState, Labeled, StagingBuffer, Tlas},
28    scratch::ScratchBuffer,
29    snatch::SnatchGuard,
30};
31use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
32
33use crate::id::{BlasId, TlasId};
34
/// A pending BLAS build: the resource, the hal geometry entries describing its
/// inputs, and the offset into the shared scratch buffer at which it is built.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
40
/// A pending TLAS build: the resource, the hal instance entries (which borrow
/// the TLAS's own `instance_buffer`), and its offset into the shared scratch
/// buffer.
// NOTE(review): the "Unsafe" name presumably refers to the borrowed hal buffer
// reference held in `entries` — confirm against the rest of the module.
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
46
/// A pending TLAS build plus the byte range of its serialized instance data
/// within the shared instance staging source built in
/// `build_acceleration_structures`.
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}
51
52impl Global {
53    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
54        self.hub.blas_s.get(blas_id).get()
55    }
56
57    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
58        self.hub.tlas_s.get(tlas_id).get()
59    }
60
61    pub fn command_encoder_mark_acceleration_structures_built(
62        &self,
63        command_encoder_id: CommandEncoderId,
64        blas_ids: &[BlasId],
65        tlas_ids: &[TlasId],
66    ) -> Result<(), EncoderStateError> {
67        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
68
69        let hub = &self.hub;
70
71        let cmd_enc = hub.command_encoders.get(command_encoder_id);
72
73        let mut cmd_buf_data = cmd_enc.data.lock();
74        cmd_buf_data.with_buffer(
75            crate::command::EncodingApi::Raw,
76            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
77                let device = &cmd_enc.device;
78                device.check_is_valid()?;
79                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
80
81                let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len());
82
83                for blas in blas_ids {
84                    let blas = hub.blas_s.get(*blas).get()?;
85                    build_command.blas_s_built.push(blas);
86                }
87
88                for tlas in tlas_ids {
89                    let tlas = hub.tlas_s.get(*tlas).get()?;
90                    build_command.tlas_s_built.push(TlasBuild {
91                        tlas,
92                        dependencies: Vec::new(),
93                    });
94                }
95
96                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
97                Ok(())
98            },
99        )
100    }
101
102    pub fn command_encoder_build_acceleration_structures<'a>(
103        &self,
104        command_encoder_id: CommandEncoderId,
105        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
106        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
107    ) -> Result<(), EncoderStateError> {
108        profiling::scope!("CommandEncoder::build_acceleration_structures");
109
110        let hub = &self.hub;
111
112        let cmd_enc = hub.command_encoders.get(command_encoder_id);
113        let mut cmd_buf_data = cmd_enc.data.lock();
114
115        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
116            let blas = blas_iter
117                .map(|blas_entry| {
118                    let geometries = match blas_entry.geometries {
119                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
120                            let tri_geo = triangle_geometries
121                                .map(|tg| {
122                                    Ok(ArcBlasTriangleGeometry {
123                                        size: tg.size.clone(),
124                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
125                                        index_buffer: tg
126                                            .index_buffer
127                                            .map(|id| self.resolve_buffer_id(id))
128                                            .transpose()?,
129                                        transform_buffer: tg
130                                            .transform_buffer
131                                            .map(|id| self.resolve_buffer_id(id))
132                                            .transpose()?,
133                                        first_vertex: tg.first_vertex,
134                                        vertex_stride: tg.vertex_stride,
135                                        first_index: tg.first_index,
136                                        transform_buffer_offset: tg.transform_buffer_offset,
137                                    })
138                                })
139                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
140                            ArcBlasGeometries::TriangleGeometries(tri_geo)
141                        }
142                    };
143                    Ok(ArcBlasBuildEntry {
144                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
145                        geometries,
146                    })
147                })
148                .collect::<Result<_, BuildAccelerationStructureError>>()?;
149
150            let tlas = tlas_iter
151                .map(|tlas_package| {
152                    let instances = tlas_package
153                        .instances
154                        .map(|instance| {
155                            instance
156                                .as_ref()
157                                .map(|instance| {
158                                    Ok(ArcTlasInstance {
159                                        blas: self.resolve_blas_id(instance.blas_id)?,
160                                        transform: *instance.transform,
161                                        custom_data: instance.custom_data,
162                                        mask: instance.mask,
163                                    })
164                                })
165                                .transpose()
166                        })
167                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
168                    Ok(ArcTlasPackage {
169                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
170                        instances,
171                        lowest_unmodified: tlas_package.lowest_unmodified,
172                    })
173                })
174                .collect::<Result<_, BuildAccelerationStructureError>>()?;
175
176            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
177        })
178    }
179}
180
/// Encodes the BLAS and TLAS builds described by `blas` and `tlas` into
/// `state`'s raw encoder.
///
/// Steps, in order: validate/collect BLAS inputs (via `iter_blas`), validate
/// TLAS packages and serialize their instances into a CPU staging source,
/// allocate one scratch buffer sized for the larger of the two phases, encode
/// the BLAS builds (`build_blas`), then — if any TLAS is present — upload the
/// instance data through a staging buffer and encode the TLAS builds with the
/// required barriers. The scratch buffer (and staging buffer, if created) are
/// parked in `state.temp_resources` so they outlive submission, and a single
/// `AsAction::Build` is recorded for submit-time validation.
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len());
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::with_capacity(blas.len());
    // Validates each BLAS entry, accumulates its scratch requirement into
    // `scratch_buffer_blas_size`, and collects the input-buffer barriers.
    iter_blas(
        blas.iter(),
        &mut build_command,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        state,
    )?;

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::with_capacity(tlas.len());
    // CPU-side concatenation of every TLAS's serialized instance data; each
    // TlasStore records its slice of this buffer via `range`.
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for package in tlas.iter() {
        let tlas = &package.tlas;
        state.tracker.tlas_s.insert_single(tlas.clone());

        // Each TLAS gets its own aligned sub-range of the shared scratch buffer.
        let scratch_buffer_offset = scratch_buffer_tlas_size;
        // NOTE(review): `build_scratch_size` is truncated to u32 before
        // aligning — assumes scratch sizes fit in 32 bits; confirm upstream.
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            state.device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        // `flatten` skips the `None` (unset) instance slots.
        for instance in package.instances.iter().flatten() {
            // The instance custom index must fit in 24 bits.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            // Serialize the instance into the backend's on-GPU layout.
            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // A vertex-return TLAS may only reference vertex-return BLASes.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        // Record the BLAS dependencies so submission can verify they were built.
        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // One scratch buffer serves both phases; BLAS and TLAS builds are
    // separated by a barrier, so the larger of the two sizes suffices.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        // if the size is zero there is nothing to build
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Self-transition on the scratch buffer — presumably acts as an execution
    // barrier between the BLAS and TLAS phases; confirm against `build_blas`.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        // Updates are not implemented: every build is a full rebuild.
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::warn!("build_acceleration_structures called with PreferUpdate, but only rebuild is implemented");
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::with_capacity(blas.len());

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    // Encodes input barriers and the BLAS builds themselves.
    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        // Upload the serialized instance data; `None` when every TLAS in this
        // build has zero instances.
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        // Make the freshly written staging data visible as a copy source.
        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // Skip TLASes whose instance-data slice is empty (nothing to copy).
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            // Deferred COPY_DST -> TLAS_INPUT transition, issued in one batch
            // after all copies below.
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Transition the instance buffer into COPY_DST, then copy this
                // TLAS's slice of the staging data into it at offset 0.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                // `staging_buffer` is Some here: a non-empty range implies
                // `instance_buffer_staging_source` was non-empty.
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            // Flush the batched COPY_DST -> TLAS_INPUT transitions, build all
            // TLASes, then make the results visible to shaders.
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        // Keep the staging buffer alive until the GPU is done with it.
        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    // Likewise keep the scratch buffer alive past submission.
    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
450
451impl CommandBufferMutable {
452    pub(crate) fn validate_acceleration_structure_actions(
453        &self,
454        snatch_guard: &SnatchGuard,
455        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
456    ) -> Result<(), ValidateAsActionsError> {
457        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
458        for action in &self.as_actions {
459            match action {
460                AsAction::Build(build) => {
461                    let build_command_index = NonZeroU64::new(
462                        command_index_guard.next_acceleration_structure_build_command_index,
463                    )
464                    .unwrap();
465
466                    command_index_guard.next_acceleration_structure_build_command_index += 1;
467                    for blas in build.blas_s_built.iter() {
468                        let mut state_lock = blas.compacted_state.lock();
469                        *state_lock = match *state_lock {
470                            BlasCompactState::Compacted => {
471                                unreachable!("Should be validated out in build.")
472                            }
473                            // Reset the compacted state to idle. This means any prepares, before mapping their
474                            // internal buffer, will terminate.
475                            _ => BlasCompactState::Idle,
476                        };
477                        *blas.built_index.write() = Some(build_command_index);
478                    }
479
480                    for tlas_build in build.tlas_s_built.iter() {
481                        for blas in &tlas_build.dependencies {
482                            if blas.built_index.read().is_none() {
483                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
484                                    blas.error_ident(),
485                                    tlas_build.tlas.error_ident(),
486                                ));
487                            }
488                        }
489                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
490                        tlas_build
491                            .tlas
492                            .dependencies
493                            .write()
494                            .clone_from(&tlas_build.dependencies)
495                    }
496                }
497                AsAction::UseTlas(tlas) => {
498                    let tlas_build_index = tlas.built_index.read();
499                    let dependencies = tlas.dependencies.read();
500
501                    if (*tlas_build_index).is_none() {
502                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
503                    }
504                    for blas in dependencies.deref() {
505                        let blas_build_index = *blas.built_index.read();
506                        if blas_build_index.is_none() {
507                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
508                                tlas.error_ident(),
509                                blas.error_ident(),
510                            ));
511                        }
512                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
513                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
514                                blas.error_ident(),
515                                tlas.error_ident(),
516                            ));
517                        }
518                        blas.try_raw(snatch_guard)?;
519                    }
520                }
521            }
522        }
523        Ok(())
524    }
525
526    pub(crate) fn set_acceleration_structure_dependencies(&self, snatch_guard: &SnatchGuard) {
527        profiling::scope!("CommandEncoder::[submission]::set_acceleration_structure_dependencies");
528        let tlas_dependencies_locks: Vec<_> = self
529            .as_actions
530            .iter()
531            .filter_map(|action| {
532                if let AsAction::UseTlas(tlas) = action {
533                    Some(tlas.dependencies.read())
534                } else {
535                    None
536                }
537            })
538            .collect();
539        let mut tlas_dependencies_lock_iter = tlas_dependencies_locks.iter();
540        let mut dependencies = Vec::new();
541        for action in &self.as_actions {
542            match action {
543                AsAction::Build(build) => {
544                    for tlas_build in build.tlas_s_built.iter() {
545                        for dependency in &tlas_build.dependencies {
546                            if let Some(dependency) = dependency.raw(snatch_guard) {
547                                dependencies.push(dependency);
548                            }
549                        }
550                    }
551                }
552                AsAction::UseTlas(_tlas) => {
553                    let tlas_dependencies = tlas_dependencies_lock_iter.next().unwrap(); // _tlas.dependencies.read();
554                    for dependency in tlas_dependencies.iter() {
555                        if let Some(dependency) = dependency.raw(snatch_guard) {
556                            dependencies.push(dependency);
557                        }
558                    }
559                }
560            }
561        }
562        if !dependencies.is_empty() {
563            unsafe {
564                self.encoder
565                    .raw
566                    .set_acceleration_structure_dependencies(&self.encoder.list, &dependencies);
567            }
568        }
569    }
570}
571
/// Iterates over the BLAS entries and their geometry, pushing the buffers into a storage vector (and performing some validation).
573fn iter_blas<'snatch_guard: 'buffers, 'buffers>(
574    blas_iter: impl Iterator<Item = &'buffers OwnedBlasBuildEntry<ArcReferences>>,
575    build_command: &mut AsBuild,
576    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
577    scratch_buffer_blas_size: &mut u64,
578    blas_storage: &mut Vec<BlasStore<'buffers>>,
579    state: &mut EncodingState<'snatch_guard, '_>,
580) -> Result<(), BuildAccelerationStructureError> {
581    for entry in blas_iter {
582        let blas = &entry.blas;
583        state.tracker.blas_s.insert_single(blas.clone());
584
585        build_command.blas_s_built.push(blas.clone());
586
587        match &entry.geometries {
588            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
589                let mut triangle_entries =
590                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
591
592                for (i, mesh) in triangle_geometries.iter().enumerate() {
593                    let size_desc = match &blas.sizes {
594                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
595                    };
596                    if i >= size_desc.len() {
597                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
598                            blas.error_ident(),
599                        ));
600                    }
601                    let size_desc = &size_desc[i];
602
603                    if size_desc.flags != mesh.size.flags {
604                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
605                            blas.error_ident(),
606                            size_desc.flags,
607                            mesh.size.flags,
608                        ));
609                    }
610
611                    if size_desc.vertex_count < mesh.size.vertex_count {
612                        return Err(
613                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
614                                blas.error_ident(),
615                                size_desc.vertex_count,
616                                mesh.size.vertex_count,
617                            ),
618                        );
619                    }
620
621                    if size_desc.vertex_format != mesh.size.vertex_format {
622                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
623                            blas.error_ident(),
624                            size_desc.vertex_format,
625                            mesh.size.vertex_format,
626                        ));
627                    }
628
629                    if size_desc
630                        .vertex_format
631                        .min_acceleration_structure_vertex_stride()
632                        > mesh.vertex_stride
633                    {
634                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
635                            blas.error_ident(),
636                            size_desc
637                                .vertex_format
638                                .min_acceleration_structure_vertex_stride(),
639                            mesh.vertex_stride,
640                        ));
641                    }
642
643                    if mesh.vertex_stride
644                        % size_desc
645                            .vertex_format
646                            .acceleration_structure_stride_alignment()
647                        != 0
648                    {
649                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
650                            blas.error_ident(),
651                            size_desc
652                                .vertex_format
653                                .acceleration_structure_stride_alignment(),
654                            mesh.vertex_stride,
655                        ));
656                    }
657
658                    match (size_desc.index_count, mesh.size.index_count) {
659                        (Some(_), None) | (None, Some(_)) => {
660                            return Err(
661                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
662                                    blas.error_ident(),
663                                ),
664                            )
665                        }
666                        (Some(create), Some(build)) if create < build => {
667                            return Err(
668                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
669                                    blas.error_ident(),
670                                    create,
671                                    build,
672                                ),
673                            )
674                        }
675                        _ => {}
676                    }
677
678                    if size_desc.index_format != mesh.size.index_format {
679                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
680                            blas.error_ident(),
681                            size_desc.index_format,
682                            mesh.size.index_format,
683                        ));
684                    }
685
686                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
687                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
688                            blas.error_ident(),
689                        ));
690                    }
691                    let vertex_buffer = mesh.vertex_buffer.clone();
692                    let vertex_pending = state.tracker.buffers.set_single(
693                        &vertex_buffer,
694                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
695                    );
696                    let vertex_buffer = {
697                        let vertex_raw = mesh.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
698                        let vertex_buffer = &mesh.vertex_buffer;
699                        vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
700
701                        if let Some(barrier) = vertex_pending.map(|pending| {
702                            pending.into_hal(vertex_buffer.as_ref(), state.snatch_guard)
703                        }) {
704                            input_barriers.push(barrier);
705                        }
706                        if vertex_buffer.size
707                            < (mesh.size.vertex_count + mesh.first_vertex) as u64
708                                * mesh.vertex_stride
709                        {
710                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
711                                vertex_buffer.error_ident(),
712                                vertex_buffer.size,
713                                (mesh.size.vertex_count + mesh.first_vertex) as u64
714                                    * mesh.vertex_stride,
715                            ));
716                        }
717                        let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
718                        state.buffer_memory_init_actions.extend(
719                            vertex_buffer.initialization_status.read().create_action(
720                                vertex_buffer,
721                                vertex_buffer_offset
722                                    ..(vertex_buffer_offset
723                                        + mesh.size.vertex_count as u64 * mesh.vertex_stride),
724                                MemoryInitKind::NeedsInitializedMemory,
725                            ),
726                        );
727                        vertex_raw
728                    };
729                    let index_buffer = if let Some(ref index_buffer) = mesh.index_buffer {
730                        if mesh.first_index.is_none()
731                            || mesh.size.index_count.is_none()
732                            || mesh.size.index_count.is_none()
733                        {
734                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
735                                index_buffer.error_ident(),
736                            ));
737                        }
738                        let index_pending = state.tracker.buffers.set_single(
739                            index_buffer,
740                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
741                        );
742                        let index_raw = index_buffer.try_raw(state.snatch_guard)?;
743                        index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
744
745                        if let Some(barrier) = index_pending.map(|pending| {
746                            pending.into_hal(index_buffer.as_ref(), state.snatch_guard)
747                        }) {
748                            input_barriers.push(barrier);
749                        }
750                        let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
751                        let offset = mesh.first_index.unwrap() as u64 * index_stride;
752                        let index_buffer_size =
753                            mesh.size.index_count.unwrap() as u64 * index_stride;
754
755                        if mesh.size.index_count.unwrap() % 3 != 0 {
756                            return Err(BuildAccelerationStructureError::InvalidIndexCount(
757                                index_buffer.error_ident(),
758                                mesh.size.index_count.unwrap(),
759                            ));
760                        }
761                        if index_buffer.size
762                            < mesh.size.index_count.unwrap() as u64 * index_stride + offset
763                        {
764                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
765                                index_buffer.error_ident(),
766                                index_buffer.size,
767                                mesh.size.index_count.unwrap() as u64 * index_stride + offset,
768                            ));
769                        }
770
771                        state.buffer_memory_init_actions.extend(
772                            index_buffer.initialization_status.read().create_action(
773                                index_buffer,
774                                offset..(offset + index_buffer_size),
775                                MemoryInitKind::NeedsInitializedMemory,
776                            ),
777                        );
778                        Some(index_raw)
779                    } else {
780                        None
781                    };
782                    let transform_buffer = if let Some(ref transform_buffer) = mesh.transform_buffer
783                    {
784                        if !blas
785                            .flags
786                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
787                        {
788                            return Err(BuildAccelerationStructureError::UseTransformMissing(
789                                blas.error_ident(),
790                            ));
791                        }
792                        if mesh.transform_buffer_offset.is_none() {
793                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
794                                transform_buffer.error_ident(),
795                            ));
796                        }
797                        let transform_pending = state.tracker.buffers.set_single(
798                            transform_buffer,
799                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
800                        );
801                        if mesh.transform_buffer_offset.is_none() {
802                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
803                                transform_buffer.error_ident(),
804                            ));
805                        }
806                        let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
807                        transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
808
809                        if let Some(barrier) = transform_pending.map(|pending| {
810                            pending.into_hal(transform_buffer.as_ref(), state.snatch_guard)
811                        }) {
812                            input_barriers.push(barrier);
813                        }
814
815                        let offset = mesh.transform_buffer_offset.unwrap();
816
817                        if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
818                            return Err(
819                                BuildAccelerationStructureError::UnalignedTransformBufferOffset(
820                                    transform_buffer.error_ident(),
821                                ),
822                            );
823                        }
824                        if transform_buffer.size < 48 + offset {
825                            return Err(BuildAccelerationStructureError::InsufficientBufferSize(
826                                transform_buffer.error_ident(),
827                                transform_buffer.size,
828                                48 + offset,
829                            ));
830                        }
831                        state.buffer_memory_init_actions.extend(
832                            transform_buffer.initialization_status.read().create_action(
833                                transform_buffer,
834                                offset..(offset + 48),
835                                MemoryInitKind::NeedsInitializedMemory,
836                            ),
837                        );
838                        Some(transform_raw)
839                    } else {
840                        if blas
841                            .flags
842                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
843                        {
844                            return Err(BuildAccelerationStructureError::TransformMissing(
845                                blas.error_ident(),
846                            ));
847                        }
848                        None
849                    };
850
851                    let triangles = hal::AccelerationStructureTriangles {
852                        vertex_buffer: Some(vertex_buffer),
853                        vertex_format: mesh.size.vertex_format,
854                        first_vertex: mesh.first_vertex,
855                        vertex_count: mesh.size.vertex_count,
856                        vertex_stride: mesh.vertex_stride,
857                        indices: index_buffer.map(|index_buffer| {
858                            let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
859                            hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
860                                format: mesh.size.index_format.unwrap(),
861                                buffer: Some(index_buffer),
862                                offset: mesh.first_index.unwrap() * index_stride,
863                                count: mesh.size.index_count.unwrap(),
864                            }
865                        }),
866                        transform: transform_buffer.map(|transform_buffer| {
867                            hal::AccelerationStructureTriangleTransform {
868                                buffer: transform_buffer,
869                                offset: mesh.transform_buffer_offset.unwrap() as u32,
870                            }
871                        }),
872                        flags: mesh.size.flags,
873                    };
874                    triangle_entries.push(triangles);
875                }
876
877                {
878                    let scratch_buffer_offset = *scratch_buffer_blas_size;
879                    *scratch_buffer_blas_size += align_to(
880                        blas.size_info.build_scratch_size as u32,
881                        state.device.alignments.ray_tracing_scratch_buffer_alignment,
882                    ) as u64;
883
884                    blas_storage.push(BlasStore {
885                        blas: blas.clone(),
886                        entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
887                        scratch_buffer_offset,
888                    });
889                }
890            }
891        }
892    }
893    Ok(())
894}
895
896fn map_blas<'a>(
897    storage: &'a BlasStore<'_>,
898    scratch_buffer: &'a dyn hal::DynBuffer,
899    snatch_guard: &'a SnatchGuard,
900    blases_compactable: &mut Vec<(
901        &'a dyn hal::DynBuffer,
902        &'a dyn hal::DynAccelerationStructure,
903    )>,
904) -> Result<
905    hal::BuildAccelerationStructureDescriptor<
906        'a,
907        dyn hal::DynBuffer,
908        dyn hal::DynAccelerationStructure,
909    >,
910    BuildAccelerationStructureError,
911> {
912    let BlasStore {
913        blas,
914        entries,
915        scratch_buffer_offset,
916    } = storage;
917    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
918        log::debug!("only rebuild implemented")
919    }
920    let raw = blas.try_raw(snatch_guard)?;
921
922    let state_lock = blas.compacted_state.lock();
923    if let BlasCompactState::Compacted = *state_lock {
924        return Err(BuildAccelerationStructureError::CompactedBlas(
925            blas.error_ident(),
926        ));
927    }
928
929    if blas
930        .flags
931        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
932    {
933        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
934    }
935    Ok(hal::BuildAccelerationStructureDescriptor {
936        entries,
937        mode: hal::AccelerationStructureBuildMode::Build,
938        flags: blas.flags,
939        source_acceleration_structure: None,
940        destination_acceleration_structure: raw,
941        scratch_buffer,
942        scratch_buffer_offset: *scratch_buffer_offset,
943    })
944}
945
946fn build_blas<'a>(
947    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
948    blas_present: bool,
949    tlas_present: bool,
950    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
951    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
952        'a,
953        dyn hal::DynBuffer,
954        dyn hal::DynAccelerationStructure,
955    >],
956    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
957    blas_s_for_compaction: Vec<(
958        &'a dyn hal::DynBuffer,
959        &'a dyn hal::DynAccelerationStructure,
960    )>,
961) {
962    unsafe {
963        cmd_buf_raw.transition_buffers(&input_barriers);
964    }
965
966    if blas_present {
967        unsafe {
968            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
969                usage: hal::StateTransition {
970                    from: hal::AccelerationStructureUses::BUILD_INPUT,
971                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
972                },
973            });
974
975            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
976        }
977    }
978
979    if blas_present && tlas_present {
980        unsafe {
981            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
982        }
983    }
984
985    let mut source_usage = hal::AccelerationStructureUses::empty();
986    let mut destination_usage = hal::AccelerationStructureUses::empty();
987    for &(buf, blas) in blas_s_for_compaction.iter() {
988        unsafe {
989            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
990                buffer: buf,
991                usage: hal::StateTransition {
992                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
993                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
994                },
995            }])
996        }
997        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
998        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
999    }
1000
1001    if blas_present {
1002        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1003        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1004    }
1005    if tlas_present {
1006        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1007        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1008    }
1009    unsafe {
1010        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1011            usage: hal::StateTransition {
1012                from: source_usage,
1013                to: destination_usage,
1014            },
1015        });
1016    }
1017}