wgpu_core/command/ray_tracing.rs

use alloc::{boxed::Box, sync::Arc, vec::Vec};
use core::{
    cmp::max,
    num::NonZeroU64,
    ops::{Deref, Range},
};

use wgt::{math::align_to, BufferUsages, BufferUses, Features};

use crate::{
    command::CommandBufferMutable,
    device::queue::TempResource,
    global::Global,
    hub::Hub,
    id::CommandEncoderId,
    init_tracker::MemoryInitKind,
    ray_tracing::{
        BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
        TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
        TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
    },
    resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
    scratch::ScratchBuffer,
    snatch::SnatchGuard,
    track::PendingTransition,
};
use crate::{
    command::CommandEncoder,
    ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError},
};
use crate::{command::EncoderStateError, device::resource::CommandIndices};
use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};

use crate::id::{BlasId, TlasId};

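/// Per-geometry bookkeeping collected by [iter_blas]: the vertex, index and transform
/// buffers of one triangle geometry together with their pending usage transitions, plus
/// the owning BLAS on the last geometry of each BLAS entry.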
struct TriangleBufferStore<'a> {
    vertex_buffer: Arc<Buffer>,
    vertex_transition: Option<PendingTransition<BufferUses>>,
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    geometry: BlasTriangleGeometry<'a>,
    ending_blas: Option<Arc<Blas>>,
}

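/// A BLAS ready to be built: the BLAS, its geometry entries, and the offset of its
/// region within the shared scratch buffer.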
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

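/// A TLAS ready to be built: the TLAS, its instance entries, and the offset of its
/// region within the shared scratch buffer.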
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

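/// An [UnsafeTlasStore] together with the byte range of its instance data within the
/// instance staging source.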
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}

impl Global {
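    /// Records an [AsAction::Build] marking the given BLASes and TLASes as built,
    /// without encoding any GPU work. Requires [Features::EXPERIMENTAL_RAY_QUERY].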
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.record_with(
            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
                let device = &cmd_enc.device;
                device.check_is_valid()?;
                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

                let mut build_command = AsBuild::default();

                for blas in blas_ids {
                    let blas = hub.blas_s.get(*blas).get()?;
                    build_command.blas_s_built.push(blas);
                }

                for tlas in tlas_ids {
                    let tlas = hub.tlas_s.get(*tlas).get()?;
                    build_command.tlas_s_built.push(TlasBuild {
                        tlas,
                        dependencies: Vec::new(),
                    });
                }

                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
                Ok(())
            },
        )
    }

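    /// Validates and encodes acceleration structure builds on this command encoder.
    /// The entries are first converted into owned trace types (so they can be recorded
    /// by the trace feature) and then re-borrowed into [build_acceleration_structures].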
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), EncoderStateError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_enc = hub.command_encoders.get(command_encoder_id);

        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

        let mut cmd_buf_data = cmd_enc.data.lock();
        cmd_buf_data.record_with(|cmd_buf_data| {
            let blas_iter = trace_blas.iter().map(|blas_entry| {
                let geometries = match &blas_entry.geometries {
                    TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                        let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                            size: &tg.size,
                            vertex_buffer: tg.vertex_buffer,
                            index_buffer: tg.index_buffer,
                            transform_buffer: tg.transform_buffer,
                            first_vertex: tg.first_vertex,
                            vertex_stride: tg.vertex_stride,
                            first_index: tg.first_index,
                            transform_buffer_offset: tg.transform_buffer_offset,
                        });
                        BlasGeometries::TriangleGeometries(Box::new(iter))
                    }
                };
                BlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            });

            let tlas_iter = trace_tlas.iter().map(|tlas_package| {
                let instances = tlas_package.instances.iter().map(|instance| {
                    instance.as_ref().map(|instance| TlasInstance {
                        blas_id: instance.blas_id,
                        transform: &instance.transform,
                        custom_data: instance.custom_data,
                        mask: instance.mask,
                    })
                });
                TlasPackage {
                    tlas_id: tlas_package.tlas_id,
                    instances: Box::new(instances),
                    lowest_unmodified: tlas_package.lowest_unmodified,
                }
            });

            build_acceleration_structures(
                cmd_buf_data,
                hub,
                &cmd_enc,
                trace_blas.clone(),
                trace_tlas.clone(),
                blas_iter,
                tlas_iter,
            )
        })
    }
}

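/// Core implementation of acceleration structure building: validates the BLAS and TLAS
/// entries, computes scratch and instance staging requirements, records the hal build
/// commands and barriers, and pushes an [AsAction::Build] for submission-time validation.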
pub(crate) fn build_acceleration_structures<'a>(
    cmd_buf_data: &'a mut CommandBufferMutable,
    hub: &'a Hub,
    cmd_enc: &'a Arc<CommandEncoder>,
    trace_blas: Vec<TraceBlasBuildEntry>,
    trace_tlas: Vec<TraceTlasPackage>,
    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
    tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
) -> Result<(), BuildAccelerationStructureError> {
    #[cfg(feature = "trace")]
    if let Some(ref mut list) = cmd_buf_data.trace_commands {
        list.push(crate::command::Command::BuildAccelerationStructures {
            blas: trace_blas,
            tlas: trace_tlas,
        });
    }
    #[cfg(not(feature = "trace"))]
    {
        let _ = trace_blas;
        let _ = trace_tlas;
    }

    let device = &cmd_enc.device;
    device.check_is_valid()?;
    device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;

    let mut build_command = AsBuild::default();
    let mut buf_storage = Vec::new();
    iter_blas(
        blas_iter,
        cmd_buf_data,
        &mut build_command,
        &mut buf_storage,
        hub,
    )?;

    let snatch_guard = device.snatchable_lock.read();
    let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
    let mut scratch_buffer_blas_size = 0;
    let mut blas_storage = Vec::new();
    iter_buffers(
        &mut buf_storage,
        &snatch_guard,
        &mut input_barriers,
        cmd_buf_data,
        &mut scratch_buffer_blas_size,
        &mut blas_storage,
        hub,
        device.alignments.ray_tracing_scratch_buffer_alignment,
    )?;
    let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

    for package in tlas_iter {
        let tlas = hub.tlas_s.get(package.tlas_id).get()?;

        cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

        tlas_lock_store.push((Some(package), tlas))
    }

    let mut scratch_buffer_tlas_size = 0;
    let mut tlas_storage = Vec::<TlasStore>::new();
    let mut instance_buffer_staging_source = Vec::<u8>::new();

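    // For each TLAS package: encode its instances into the CPU-side staging source,
    // validate them against their BLAS dependencies, and accumulate the scratch space
    // the TLAS build will need.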
    for (package, tlas) in &mut tlas_lock_store {
        let package = package.take().unwrap();

        let scratch_buffer_offset = scratch_buffer_tlas_size;
        scratch_buffer_tlas_size += align_to(
            tlas.size_info.build_scratch_size as u32,
            device.alignments.ray_tracing_scratch_buffer_alignment,
        ) as u64;

        let first_byte_index = instance_buffer_staging_source.len();

        let mut dependencies = Vec::new();

        let mut instance_count = 0;
        for instance in package.instances.flatten() {
            if instance.custom_data >= (1u32 << 24u32) {
                return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                    tlas.error_ident(),
                ));
            }
            let blas = hub.blas_s.get(instance.blas_id).get()?;

            cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

            instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                hal::TlasInstance {
                    transform: *instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                    blas_address: blas.handle,
                },
            ));

            if tlas
                .flags
                .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                && !blas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
            {
                return Err(
                    BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                        tlas.error_ident(),
                        blas.error_ident(),
                    ),
                );
            }

            instance_count += 1;

            dependencies.push(blas.clone());
        }

        build_command.tlas_s_built.push(TlasBuild {
            tlas: tlas.clone(),
            dependencies,
        });

        if instance_count > tlas.max_instance_count {
            return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                tlas.error_ident(),
                instance_count,
                tlas.max_instance_count,
            ));
        }

        tlas_storage.push(TlasStore {
            internal: UnsafeTlasStore {
                tlas: tlas.clone(),
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(tlas.instance_buffer.as_ref()),
                        offset: 0,
                        count: instance_count,
                    },
                ),
                scratch_buffer_offset,
            },
            range: first_byte_index..instance_buffer_staging_source.len(),
        });
    }

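    // All BLAS builds finish (and are fenced off by `scratch_buffer_barrier`) before the
    // TLAS builds start, so both phases share one scratch buffer sized to the larger of
    // the two requirements.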
    let Some(scratch_size) =
        wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
    else {
        // if the size is zero there is nothing to build
        return Ok(());
    };

    let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

    let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
        buffer: scratch_buffer.raw(),
        usage: hal::StateTransition {
            from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
        },
    };

    let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

    for &TlasStore {
        internal:
            UnsafeTlasStore {
                ref tlas,
                ref entries,
                ref scratch_buffer_offset,
            },
        ..
    } in &tlas_storage
    {
        if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
            log::info!("only rebuild implemented")
        }
        tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
            entries,
            mode: hal::AccelerationStructureBuildMode::Build,
            flags: tlas.flags,
            source_acceleration_structure: None,
            destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
            scratch_buffer: scratch_buffer.raw(),
            scratch_buffer_offset: *scratch_buffer_offset,
        })
    }

    let blas_present = !blas_storage.is_empty();
    let tlas_present = !tlas_storage.is_empty();

    let cmd_buf_raw = cmd_buf_data.encoder.open()?;

    let mut blas_s_compactable = Vec::new();
    let mut descriptors = Vec::new();

    for storage in &blas_storage {
        descriptors.push(map_blas(
            storage,
            scratch_buffer.raw(),
            &snatch_guard,
            &mut blas_s_compactable,
        )?);
    }

    build_blas(
        cmd_buf_raw,
        blas_present,
        tlas_present,
        input_barriers,
        &descriptors,
        scratch_buffer_barrier,
        blas_s_compactable,
    );

    if tlas_present {
        let staging_buffer = if !instance_buffer_staging_source.is_empty() {
            let mut staging_buffer = StagingBuffer::new(
                device,
                wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
            )?;
            staging_buffer.write(&instance_buffer_staging_source);
            let flushed = staging_buffer.flush();
            Some(flushed)
        } else {
            None
        };

        unsafe {
            if let Some(ref staging_buffer) = staging_buffer {
                cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: staging_buffer.raw(),
                    usage: hal::StateTransition {
                        from: BufferUses::MAP_WRITE,
                        to: BufferUses::COPY_SRC,
                    },
                }]);
            }
        }

        let mut instance_buffer_barriers = Vec::new();
        for &TlasStore {
            internal: UnsafeTlasStore { ref tlas, .. },
            ref range,
        } in &tlas_storage
        {
            let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                None => continue,
                Some(size) => size,
            };
            instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                buffer: tlas.instance_buffer.as_ref(),
                usage: hal::StateTransition {
                    from: BufferUses::COPY_DST,
                    to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                },
            });
            unsafe {
                cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        to: BufferUses::COPY_DST,
                    },
                }]);
                let temp = hal::BufferCopy {
                    src_offset: range.start as u64,
                    dst_offset: 0,
                    size,
                };
                cmd_buf_raw.copy_buffer_to_buffer(
                    // The end of the range whose size we just checked was (at that point in time)
                    // instance_buffer_staging_source.len(), and since instance_buffer_staging_source
                    // never shrinks, the staging buffer must exist, so this unwrap cannot panic.
                    staging_buffer.as_ref().unwrap().raw(),
                    tlas.instance_buffer.as_ref(),
                    &[temp],
                );
            }
        }

        unsafe {
            cmd_buf_raw.transition_buffers(&instance_buffer_barriers);

            cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                    to: hal::AccelerationStructureUses::SHADER_INPUT,
                },
            });
        }

        if let Some(staging_buffer) = staging_buffer {
            cmd_buf_data
                .temp_resources
                .push(TempResource::StagingBuffer(staging_buffer));
        }
    }

    cmd_buf_data
        .temp_resources
        .push(TempResource::ScratchBuffer(scratch_buffer));

    cmd_buf_data.as_actions.push(AsAction::Build(build_command));

    Ok(())
}

impl CommandBufferMutable {
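    /// Validates the acceleration structure actions recorded on this command buffer at
    /// submission time: assigns build indices to newly built BLASes and TLASes, and
    /// checks that every used TLAS (and each BLAS it depends on) has been built and
    /// that no dependency was rebuilt after the TLAS that references it.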
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        let mut state_lock = blas.compacted_state.lock();
                        *state_lock = match *state_lock {
                            BlasCompactState::Compacted => {
                                unreachable!("Should be validated out in build.")
                            }
                            // Reset the compacted state to idle. This means any in-flight compaction
                            // prepares will terminate before mapping their internal buffer.
                            _ => BlasCompactState::Idle,
                        };
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}

/// Iterates over the BLAS build entries and their geometries, pushing the buffers into a
/// storage vector (and also performs some validation).
fn iter_blas<'a>(
    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
    cmd_buf_data: &mut CommandBufferMutable,
    build_command: &mut AsBuild,
    buf_storage: &mut Vec<TriangleBufferStore<'a>>,
    hub: &Hub,
) -> Result<(), BuildAccelerationStructureError> {
    let mut temp_buffer = Vec::new();
    for entry in blas_iter {
        let blas = hub.blas_s.get(entry.blas_id).get()?;
        cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

        build_command.blas_s_built.push(blas.clone());

        match entry.geometries {
            BlasGeometries::TriangleGeometries(triangle_geometries) => {
                for (i, mesh) in triangle_geometries.enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != mesh.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            mesh.size.flags,
                        ));
                    }

                    if size_desc.vertex_count < mesh.size.vertex_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
                                blas.error_ident(),
                                size_desc.vertex_count,
                                mesh.size.vertex_count,
                            ),
                        );
                    }

                    if size_desc.vertex_format != mesh.size.vertex_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
                            blas.error_ident(),
                            size_desc.vertex_format,
                            mesh.size.vertex_format,
                        ));
                    }

                    if size_desc
                        .vertex_format
                        .min_acceleration_structure_vertex_stride()
                        > mesh.vertex_stride
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .min_acceleration_structure_vertex_stride(),
                            mesh.vertex_stride,
                        ));
                    }

                    if mesh.vertex_stride
                        % size_desc
                            .vertex_format
                            .acceleration_structure_stride_alignment()
                        != 0
                    {
                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
                            blas.error_ident(),
                            size_desc
                                .vertex_format
                                .acceleration_structure_stride_alignment(),
                            mesh.vertex_stride,
                        ));
                    }

                    match (size_desc.index_count, mesh.size.index_count) {
                        (Some(_), None) | (None, Some(_)) => {
                            return Err(
                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
                                    blas.error_ident(),
                                ),
                            )
                        }
                        (Some(create), Some(build)) if create < build => {
                            return Err(
                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
                                    blas.error_ident(),
                                    create,
                                    build,
                                ),
                            )
                        }
                        _ => {}
                    }

                    if size_desc.index_format != mesh.size.index_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
                            blas.error_ident(),
                            size_desc.index_format,
                            mesh.size.index_format,
                        ));
                    }

                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
                            blas.error_ident(),
                        ));
                    }
                    let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
                    let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
                        &vertex_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );
                    let index_data = if let Some(index_id) = mesh.index_buffer {
                        let index_buffer = hub.buffers.get(index_id).get()?;
                        if mesh.first_index.is_none() || mesh.size.index_count.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                index_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &index_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((index_buffer, data))
                    } else {
                        None
                    };
                    let transform_data = if let Some(transform_id) = mesh.transform_buffer {
                        if !blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::UseTransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        let transform_buffer = hub.buffers.get(transform_id).get()?;
                        if mesh.transform_buffer_offset.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                transform_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &transform_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((transform_buffer, data))
                    } else {
                        if blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::TransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        None
                    };
                    temp_buffer.push(TriangleBufferStore {
                        vertex_buffer,
                        vertex_transition: vertex_pending,
                        index_buffer_transition: index_data,
                        transform_buffer_transition: transform_data,
                        geometry: mesh,
                        ending_blas: None,
                    });
                }

                if let Some(last) = temp_buffer.last_mut() {
                    last.ending_blas = Some(blas);
                    buf_storage.append(&mut temp_buffer);
                }
            }
        }
    }
    Ok(())
}

/// Iterates over the buffers generated in [iter_blas], converting the barriers into hal barriers
/// and the triangles into [hal::AccelerationStructureEntries] (and also performs some validation).
fn iter_buffers<'a, 'b>(
    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
    snatch_guard: &'a SnatchGuard,
    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
    cmd_buf_data: &mut CommandBufferMutable,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'a>>,
    hub: &Hub,
    ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> {
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_buffer = buf.vertex_buffer.as_ref();
            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            cmd_buf_data.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    &hub.buffers.get(mesh.vertex_buffer).get()?,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            cmd_buf_data.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            cmd_buf_data.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}

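/// Converts a [BlasStore] into a hal build descriptor, rejecting BLASes that have
/// already been compacted and collecting compaction size read-back targets for BLASes
/// built with `ALLOW_COMPACTION`.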
fn map_blas<'a>(
    storage: &'a BlasStore<'_>,
    scratch_buffer: &'a dyn hal::DynBuffer,
    snatch_guard: &'a SnatchGuard,
    blases_compactable: &mut Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) -> Result<
    hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >,
    BuildAccelerationStructureError,
> {
    let BlasStore {
        blas,
        entries,
        scratch_buffer_offset,
    } = storage;
    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
        log::info!("only rebuild implemented")
    }
    let raw = blas.try_raw(snatch_guard)?;

    let state_lock = blas.compacted_state.lock();
    if let BlasCompactState::Compacted = *state_lock {
        return Err(BuildAccelerationStructureError::CompactedBlas(
            blas.error_ident(),
        ));
    }

    if blas
        .flags
        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
    {
        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
    }
    Ok(hal::BuildAccelerationStructureDescriptor {
        entries,
        mode: hal::AccelerationStructureBuildMode::Build,
        flags: blas.flags,
        source_acceleration_structure: None,
        destination_acceleration_structure: raw,
        scratch_buffer,
        scratch_buffer_offset: *scratch_buffer_offset,
    })
}

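/// Encodes the BLAS build commands: input buffer barriers, the BLAS builds themselves,
/// the scratch buffer barrier between the BLAS and TLAS phases, compaction size
/// read-backs, and the final acceleration structure usage barrier.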
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
    blas_s_for_compaction: Vec<(
        &'a dyn hal::DynBuffer,
        &'a dyn hal::DynAccelerationStructure,
    )>,
) {
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    for &(buf, blas) in blas_s_for_compaction.iter() {
        unsafe {
            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
                buffer: buf,
                usage: hal::StateTransition {
                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
                },
            }])
        }
        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
    }

    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}