wgpu_core/command/ray_tracing.rs

1use alloc::{sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::{
11    command::encoder::EncodingState,
12    ray_tracing::{AsAction, AsBuild, BlasTriangleGeometryInfo, TlasBuild, ValidateAsActionsError},
13    resource::InvalidResourceError,
14    track::Tracker,
15};
16use crate::{command::EncoderStateError, device::resource::CommandIndices};
17use crate::{
18    command::{ArcCommand, ArcReferences, CommandBufferMutable},
19    device::queue::TempResource,
20    global::Global,
21    id::CommandEncoderId,
22    init_tracker::MemoryInitKind,
23    ray_tracing::{
24        ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
25        ArcTlasPackage, BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError,
26        OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage,
27    },
28    resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
29    scratch::ScratchBuffer,
30    snatch::SnatchGuard,
31    track::PendingTransition,
32};
33use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
34
35use crate::id::{BlasId, TlasId};
36
/// Working storage for one triangle geometry of a BLAS build: the buffers it
/// reads from, the pending tracker transitions for those buffers, and the
/// geometry description itself. Produced by `iter_blas`, consumed by
/// `iter_buffers`.
struct TriangleBufferStore {
    /// Buffer holding the vertex data for this geometry.
    vertex_buffer: Arc<Buffer>,
    /// Tracker transition to issue before the vertex buffer is read, if any.
    vertex_transition: Option<PendingTransition<BufferUses>>,
    /// Optional index buffer together with its pending transition.
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// Optional transform buffer together with its pending transition.
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// Size/offset description of this triangle geometry.
    geometry: BlasTriangleGeometryInfo,
    /// Set only on the last geometry of a BLAS, marking where that BLAS's
    /// geometries end in the flattened storage vector (see `iter_blas`).
    ending_blas: Option<Arc<Blas>>,
}
45
/// A BLAS ready to be encoded: the acceleration structure, its hal entries
/// (which borrow the input buffers for `'a`), and its byte offset into the
/// shared scratch buffer.
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
51
/// A TLAS ready to be encoded, mirroring [`BlasStore`].
// NOTE(review): the "Unsafe" in the name presumably refers to the borrowed
// `entries` whose lifetime discipline is upheld manually by the caller —
// confirm against the rest of the module.
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}
57
/// A TLAS to build plus the byte range of its instance data within the shared
/// instance staging source built by `build_acceleration_structures`.
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    /// `start..end` byte range into `instance_buffer_staging_source`.
    range: Range<usize>,
}
62
63impl Global {
64    fn resolve_blas_id(&self, blas_id: BlasId) -> Result<Arc<Blas>, InvalidResourceError> {
65        self.hub.blas_s.get(blas_id).get()
66    }
67
68    fn resolve_tlas_id(&self, tlas_id: TlasId) -> Result<Arc<Tlas>, InvalidResourceError> {
69        self.hub.tlas_s.get(tlas_id).get()
70    }
71
72    pub fn command_encoder_mark_acceleration_structures_built(
73        &self,
74        command_encoder_id: CommandEncoderId,
75        blas_ids: &[BlasId],
76        tlas_ids: &[TlasId],
77    ) -> Result<(), EncoderStateError> {
78        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
79
80        let hub = &self.hub;
81
82        let cmd_enc = hub.command_encoders.get(command_encoder_id);
83
84        let mut cmd_buf_data = cmd_enc.data.lock();
85        cmd_buf_data.with_buffer(
86            crate::command::EncodingApi::Raw,
87            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
88                let device = &cmd_enc.device;
89                device.check_is_valid()?;
90                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
91
92                let mut build_command = AsBuild::default();
93
94                for blas in blas_ids {
95                    let blas = hub.blas_s.get(*blas).get()?;
96                    build_command.blas_s_built.push(blas);
97                }
98
99                for tlas in tlas_ids {
100                    let tlas = hub.tlas_s.get(*tlas).get()?;
101                    build_command.tlas_s_built.push(TlasBuild {
102                        tlas,
103                        dependencies: Vec::new(),
104                    });
105                }
106
107                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
108                Ok(())
109            },
110        )
111    }
112
113    pub fn command_encoder_build_acceleration_structures<'a>(
114        &self,
115        command_encoder_id: CommandEncoderId,
116        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
117        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
118    ) -> Result<(), EncoderStateError> {
119        profiling::scope!("CommandEncoder::build_acceleration_structures");
120
121        let hub = &self.hub;
122
123        let cmd_enc = hub.command_encoders.get(command_encoder_id);
124        let mut cmd_buf_data = cmd_enc.data.lock();
125
126        cmd_buf_data.push_with(|| -> Result<_, BuildAccelerationStructureError> {
127            let blas = blas_iter
128                .map(|blas_entry| {
129                    let geometries = match blas_entry.geometries {
130                        BlasGeometries::TriangleGeometries(triangle_geometries) => {
131                            let tri_geo = triangle_geometries
132                                .map(|tg| {
133                                    Ok(ArcBlasTriangleGeometry {
134                                        size: tg.size.clone(),
135                                        vertex_buffer: self.resolve_buffer_id(tg.vertex_buffer)?,
136                                        index_buffer: tg
137                                            .index_buffer
138                                            .map(|id| self.resolve_buffer_id(id))
139                                            .transpose()?,
140                                        transform_buffer: tg
141                                            .transform_buffer
142                                            .map(|id| self.resolve_buffer_id(id))
143                                            .transpose()?,
144                                        first_vertex: tg.first_vertex,
145                                        vertex_stride: tg.vertex_stride,
146                                        first_index: tg.first_index,
147                                        transform_buffer_offset: tg.transform_buffer_offset,
148                                    })
149                                })
150                                .collect::<Result<_, BuildAccelerationStructureError>>()?;
151                            ArcBlasGeometries::TriangleGeometries(tri_geo)
152                        }
153                    };
154                    Ok(ArcBlasBuildEntry {
155                        blas: self.resolve_blas_id(blas_entry.blas_id)?,
156                        geometries,
157                    })
158                })
159                .collect::<Result<_, BuildAccelerationStructureError>>()?;
160
161            let tlas = tlas_iter
162                .map(|tlas_package| {
163                    let instances = tlas_package
164                        .instances
165                        .map(|instance| {
166                            instance
167                                .as_ref()
168                                .map(|instance| {
169                                    Ok(ArcTlasInstance {
170                                        blas: self.resolve_blas_id(instance.blas_id)?,
171                                        transform: *instance.transform,
172                                        custom_data: instance.custom_data,
173                                        mask: instance.mask,
174                                    })
175                                })
176                                .transpose()
177                        })
178                        .collect::<Result<_, BuildAccelerationStructureError>>()?;
179                    Ok(ArcTlasPackage {
180                        tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
181                        instances,
182                        lowest_unmodified: tlas_package.lowest_unmodified,
183                    })
184                })
185                .collect::<Result<_, BuildAccelerationStructureError>>()?;
186
187            Ok(ArcCommand::BuildAccelerationStructures { blas, tlas })
188        })
189    }
190}
191
/// Encodes acceleration-structure builds for the given BLASes and TLASes:
/// validates inputs, tracks the resources, allocates a shared scratch buffer,
/// stages TLAS instance data, and records the hal build commands plus the
/// barriers around them into `state.raw_encoder`.
pub(crate) fn build_acceleration_structures(
    state: &mut EncodingState,
    blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
    tlas: Vec<OwnedTlasPackage<ArcReferences>>,
) -> Result<(), BuildAccelerationStructureError> {
    state
        .device
        .require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    // Validate the BLAS entries and flatten their triangle geometries (plus
    // pending tracker transitions for their input buffers) into `buf_storage`.
    let mut build_command = AsBuild::default();
    let mut buf_storage = Vec::new();
    iter_blas(
        blas.into_iter(),
        state.tracker,
        &mut build_command,
        &mut buf_storage,
    )?;

    // Convert the stored buffers into hal barriers/entries and accumulate the
    // scratch space required by the BLAS builds.
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::new();
    iter_buffers(
        state,
        &mut buf_storage,
        &mut input_barriers,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
    )?;
    let mut tlas_lock_store = Vec::<(Option<OwnedTlasPackage<ArcReferences>>, Arc<Tlas>)>::new();

    for package in tlas.into_iter() {
        let tlas = package.tlas.clone();
        state.tracker.tlas_s.insert_single(tlas.clone());
        tlas_lock_store.push((Some(package), tlas))
    }

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::new();
    // All TLAS instance records are serialized back-to-back into this CPU-side
    // buffer; each TlasStore remembers its byte range within it.
    let mut instance_buffer_staging_source = Vec::<u8>::new();

    for (package, tlas) in &mut tlas_lock_store {
        // `take` is safe: each entry was pushed as `Some` and visited once.
        let package = package.take().unwrap();

        // Sub-allocate this TLAS's region of the shared scratch buffer.
        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            state.device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.into_iter().flatten() {
            // Custom index is limited to 24 bits by the instance format.
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = &instance.blas;
            state.tracker.blas_s.insert_single(blas.clone());

            // Serialize the instance into the backend's byte layout.
            instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            // A vertex-return TLAS may only reference vertex-return BLASes.
            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

    // BLAS and TLAS builds are encoded back-to-back, so a single scratch
    // buffer sized for the larger phase serves both.
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        // if the size is zero there is nothing to build
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(state.device, scratch_size)?;

    // Scratch-to-scratch barrier separating the BLAS builds from the TLAS
    // builds that reuse the same scratch memory.
    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        // Update mode is ignored: every build is a full rebuild for now.
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::info!("only rebuild implemented")
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(state.snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let raw_encoder = &mut state.raw_encoder;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::new();

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            state.snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    // Encode the BLAS phase: input barriers, builds, compaction queries, and
    // the scratch barrier ahead of the TLAS phase.
    build_blas(
        *raw_encoder,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        // Upload the serialized instance records through a staging buffer.
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                state.device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            // Zero-length range means this TLAS has no instances to copy.
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            // Deferred barrier back to the build-input state, issued after all
            // copies below have been recorded.
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                // Transition the instance buffer into COPY_DST for the copy.
                raw_encoder.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                // `unwrap` is sound: a non-empty range implies instance data
                // was staged, so `staging_buffer` is `Some`.
                raw_encoder.copy_buffer_to_buffer(
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            raw_encoder.transition_buffers(&instance_buffer_barriers);

            raw_encoder.build_acceleration_structures(&tlas_descriptors);

            // Make the freshly built structures visible to shader reads.
            raw_encoder.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        // Keep the staging buffer alive until the GPU has consumed it.
        if let Some(staging_buffer) = staging_buffer {
            state
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    state
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    state.as_actions.push(AsAction::Build(build_command));

    Ok(())
}
474
impl CommandBufferMutable {
    /// Validates, at submission time, every acceleration-structure action
    /// recorded on this command buffer, stamping built BLAS/TLAS resources
    /// with a monotonically increasing build index so ordering between builds
    /// and uses can be checked across submissions.
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    // Allocate one fresh build index shared by everything this
                    // action builds.
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            // Reset the compacted state to idle. This means any prepares, before mapping their
                            // internal buffer, will terminate.
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        // Every BLAS a TLAS depends on must itself be built.
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                // NOTE(review): argument order here (blas, tlas)
                                // differs from the UseTlas arm below (tlas, blas)
                                // — verify against the error variant's fields.
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        // A BLAS rebuilt after this TLAS was built would leave
                        // the TLAS referencing stale geometry.
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        // Ensure the BLAS has not been destroyed.
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}
550
551///iterates over the blas iterator, and it's geometry, pushing the buffers into a storage vector (and also some validation).
552fn iter_blas(
553    blas_iter: impl Iterator<Item = OwnedBlasBuildEntry<ArcReferences>>,
554    tracker: &mut Tracker,
555    build_command: &mut AsBuild,
556    buf_storage: &mut Vec<TriangleBufferStore>,
557) -> Result<(), BuildAccelerationStructureError> {
558    let mut temp_buffer = Vec::new();
559    for entry in blas_iter {
560        let blas = &entry.blas;
561        tracker.blas_s.insert_single(blas.clone());
562
563        build_command.blas_s_built.push(blas.clone());
564
565        match entry.geometries {
566            ArcBlasGeometries::TriangleGeometries(triangle_geometries) => {
567                for (i, mesh) in triangle_geometries.into_iter().enumerate() {
568                    let size_desc = match &blas.sizes {
569                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
570                    };
571                    if i >= size_desc.len() {
572                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
573                            blas.error_ident(),
574                        ));
575                    }
576                    let size_desc = &size_desc[i];
577
578                    if size_desc.flags != mesh.size.flags {
579                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
580                            blas.error_ident(),
581                            size_desc.flags,
582                            mesh.size.flags,
583                        ));
584                    }
585
586                    if size_desc.vertex_count < mesh.size.vertex_count {
587                        return Err(
588                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
589                                blas.error_ident(),
590                                size_desc.vertex_count,
591                                mesh.size.vertex_count,
592                            ),
593                        );
594                    }
595
596                    if size_desc.vertex_format != mesh.size.vertex_format {
597                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
598                            blas.error_ident(),
599                            size_desc.vertex_format,
600                            mesh.size.vertex_format,
601                        ));
602                    }
603
604                    if size_desc
605                        .vertex_format
606                        .min_acceleration_structure_vertex_stride()
607                        > mesh.vertex_stride
608                    {
609                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
610                            blas.error_ident(),
611                            size_desc
612                                .vertex_format
613                                .min_acceleration_structure_vertex_stride(),
614                            mesh.vertex_stride,
615                        ));
616                    }
617
618                    if mesh.vertex_stride
619                        % size_desc
620                            .vertex_format
621                            .acceleration_structure_stride_alignment()
622                        != 0
623                    {
624                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
625                            blas.error_ident(),
626                            size_desc
627                                .vertex_format
628                                .acceleration_structure_stride_alignment(),
629                            mesh.vertex_stride,
630                        ));
631                    }
632
633                    match (size_desc.index_count, mesh.size.index_count) {
634                        (Some(_), None) | (None, Some(_)) => {
635                            return Err(
636                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
637                                    blas.error_ident(),
638                                ),
639                            )
640                        }
641                        (Some(create), Some(build)) if create < build => {
642                            return Err(
643                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
644                                    blas.error_ident(),
645                                    create,
646                                    build,
647                                ),
648                            )
649                        }
650                        _ => {}
651                    }
652
653                    if size_desc.index_format != mesh.size.index_format {
654                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
655                            blas.error_ident(),
656                            size_desc.index_format,
657                            mesh.size.index_format,
658                        ));
659                    }
660
661                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
662                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
663                            blas.error_ident(),
664                        ));
665                    }
666                    let vertex_buffer = mesh.vertex_buffer.clone();
667                    let vertex_pending = tracker.buffers.set_single(
668                        &vertex_buffer,
669                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
670                    );
671                    let index_data = if let Some(index_buffer) = mesh.index_buffer {
672                        if mesh.first_index.is_none()
673                            || mesh.size.index_count.is_none()
674                            || mesh.size.index_count.is_none()
675                        {
676                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
677                                index_buffer.error_ident(),
678                            ));
679                        }
680                        let data = tracker.buffers.set_single(
681                            &index_buffer,
682                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
683                        );
684                        Some((index_buffer, data))
685                    } else {
686                        None
687                    };
688                    let transform_data = if let Some(transform_buffer) = mesh.transform_buffer {
689                        if !blas
690                            .flags
691                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
692                        {
693                            return Err(BuildAccelerationStructureError::UseTransformMissing(
694                                blas.error_ident(),
695                            ));
696                        }
697                        if mesh.transform_buffer_offset.is_none() {
698                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
699                                transform_buffer.error_ident(),
700                            ));
701                        }
702                        let data = tracker.buffers.set_single(
703                            &transform_buffer,
704                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
705                        );
706                        Some((transform_buffer, data))
707                    } else {
708                        if blas
709                            .flags
710                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
711                        {
712                            return Err(BuildAccelerationStructureError::TransformMissing(
713                                blas.error_ident(),
714                            ));
715                        }
716                        None
717                    };
718                    temp_buffer.push(TriangleBufferStore {
719                        vertex_buffer,
720                        vertex_transition: vertex_pending,
721                        index_buffer_transition: index_data,
722                        transform_buffer_transition: transform_data,
723                        geometry: BlasTriangleGeometryInfo {
724                            size: mesh.size,
725                            first_vertex: mesh.first_vertex,
726                            vertex_stride: mesh.vertex_stride,
727                            first_index: mesh.first_index,
728                            transform_buffer_offset: mesh.transform_buffer_offset,
729                        },
730                        ending_blas: None,
731                    });
732                }
733
734                if let Some(last) = temp_buffer.last_mut() {
735                    last.ending_blas = Some(blas.clone());
736                    buf_storage.append(&mut temp_buffer);
737                }
738            }
739        }
740    }
741    Ok(())
742}
743
/// Iterates over the buffers generated in [iter_blas], convert the barriers into hal barriers, and the triangles into [hal::AccelerationStructureEntries] (and also some validation).
///
/// `'buffers` is the lifetime of `&dyn hal::DynBuffer` in our working data,
/// i.e., needs to span until `build_acceleration_structures` finishes encoding.
/// `'snatch_guard` is the lifetime of the snatch lock acquisition.
///
/// For every [`TriangleBufferStore`] this: resolves the raw hal buffers,
/// validates usages and sizes, queues pending usage transitions into
/// `input_barriers`, records memory-init requirements, and accumulates the
/// resulting [`hal::AccelerationStructureTriangles`]. When an entry carries
/// `ending_blas` (the producer marks the last geometry of each BLAS this way),
/// the accumulated geometries are flushed into `blas_storage` and the BLAS's
/// scratch region is reserved by bumping `scratch_buffer_blas_size`.
fn iter_buffers<'snatch_guard: 'buffers, 'buffers>(
    state: &mut EncodingState<'snatch_guard, '_>,
    buf_storage: &'buffers mut Vec<TriangleBufferStore>,
    input_barriers: &mut Vec<hal::BufferBarrier<'buffers, dyn hal::DynBuffer>>,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'buffers>>,
) -> Result<(), BuildAccelerationStructureError> {
    // Geometries accumulated for the *current* BLAS; replaced with a fresh Vec
    // each time a BLAS is flushed into `blas_storage` below.
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        // Vertex buffer: raw handle, usage check, barrier, size check, init tracking.
        let vertex_buffer = {
            let vertex_raw = buf.vertex_buffer.as_ref().try_raw(state.snatch_guard)?;
            let vertex_buffer = &buf.vertex_buffer;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(buf.vertex_buffer.as_ref(), state.snatch_guard))
            {
                input_barriers.push(barrier);
            }
            // The buffer must hold `first_vertex + vertex_count` vertices at
            // `vertex_stride` spacing.
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            // The vertex range actually consumed by the build must be initialized.
            state.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    vertex_buffer,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        // Optional index buffer: same treatment, plus triangle-list validation.
        // NOTE(review): the `unwrap()`s on `index_format`/`first_index`/`index_count`
        // assume earlier validation guarantees these are `Some` whenever an index
        // buffer is present — confirm against the producer of `buf_storage`.
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(state.snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, state.snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            // A triangle list needs a whole number of triangles, i.e. a
            // multiple of three indices.
            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            state.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        // Optional transform buffer: requires an offset, aligned and with at
        // least 48 bytes past it (presumably a 3x4 f32 matrix — 48 bytes).
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(state.snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, state.snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            state.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        // Assemble the hal geometry description from the validated buffers.
        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    // hal wants a byte offset here, so scale by the stride.
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        // End of a BLAS: reserve an aligned scratch region for its build and
        // hand the accumulated geometries over.
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            // NOTE(review): the `as u32` cast assumes `build_scratch_size`
            // fits in 32 bits — values above u32::MAX would be truncated
            // before alignment. TODO confirm this is bounded upstream.
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                state.device.alignments.ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}
921
922fn map_blas<'a>(
923    storage: &'a BlasStore<'_>,
924    scratch_buffer: &'a dyn hal::DynBuffer,
925    snatch_guard: &'a SnatchGuard,
926    blases_compactable: &mut Vec<(
927        &'a dyn hal::DynBuffer,
928        &'a dyn hal::DynAccelerationStructure,
929    )>,
930) -> Result<
931    hal::BuildAccelerationStructureDescriptor<
932        'a,
933        dyn hal::DynBuffer,
934        dyn hal::DynAccelerationStructure,
935    >,
936    BuildAccelerationStructureError,
937> {
938    let BlasStore {
939        blas,
940        entries,
941        scratch_buffer_offset,
942    } = storage;
943    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
944        log::info!("only rebuild implemented")
945    }
946    let raw = blas.try_raw(snatch_guard)?;
947
948    let state_lock = blas.compacted_state.lock();
949    if let BlasCompactState::Compacted = *state_lock {
950        return Err(BuildAccelerationStructureError::CompactedBlas(
951            blas.error_ident(),
952        ));
953    }
954
955    if blas
956        .flags
957        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
958    {
959        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
960    }
961    Ok(hal::BuildAccelerationStructureDescriptor {
962        entries,
963        mode: hal::AccelerationStructureBuildMode::Build,
964        flags: blas.flags,
965        source_acceleration_structure: None,
966        destination_acceleration_structure: raw,
967        scratch_buffer,
968        scratch_buffer_offset: *scratch_buffer_offset,
969    })
970}
971
/// Encodes the BLAS build pass: input barriers, the builds themselves,
/// compaction-size readbacks, and the final acceleration-structure barrier.
///
/// `blas_present` / `tlas_present` describe what this overall command builds;
/// they steer which barriers are emitted. `blas_s_for_compaction` pairs each
/// compactable BLAS's readback buffer with its raw acceleration structure.
///
/// The exact ordering of the unsafe encoder calls below is the semantics of
/// this function — do not reorder.
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    // Transition all vertex/index/transform inputs to BLAS-input usage.
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            // Make the BLAS memory writable as a build output before building.
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    // The scratch buffer is reused by the TLAS builds that follow, so the
    // BLAS builds must finish with it first.
    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    // Accumulate the usages the closing barrier must cover.
    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            // Self-transition keeps the readback buffer in query usage
            // (presumably serving as an execution barrier between queries —
            // TODO confirm against the hal backends).
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        // NOTE(review): only `destination_usage` gains COPY_SRC here, and the
        // barrier covering it is placed *after* the readback loop — confirm
        // this is the intended synchronization.
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    // Single combined barrier moving the built BLASes into the usages the
    // rest of the command (TLAS build / shader use / compaction copies) needs.
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}