wgpu_core/command/
ray_tracing.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
11use crate::{
12    command::CommandBufferMutable,
13    device::queue::TempResource,
14    global::Global,
15    hub::Hub,
16    id::CommandEncoderId,
17    init_tracker::MemoryInitKind,
18    ray_tracing::{
19        BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
20        TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
21        TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
22    },
23    resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
24    scratch::ScratchBuffer,
25    snatch::SnatchGuard,
26    track::PendingTransition,
27};
28use crate::{command::EncoderStateError, device::resource::CommandIndices};
29use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
30
31use crate::id::{BlasId, TlasId};
32
/// Per-geometry bookkeeping collected while walking the triangle geometries of
/// a BLAS build: the resolved buffers together with the state transitions the
/// command buffer's buffer tracker reported for them.
33struct TriangleBufferStore<'a> {
    /// Resolved vertex buffer of the geometry.
34    vertex_buffer: Arc<Buffer>,
    /// Pending transition (if any) returned by the buffer tracker when the
    /// vertex buffer was set to the BLAS-input usage.
35    vertex_transition: Option<PendingTransition<BufferUses>>,
    /// Resolved index buffer and its pending transition; `None` when the
    /// geometry supplies no index buffer.
36    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// Resolved transform buffer and its pending transition; `None` when the
    /// geometry supplies no transform buffer.
37    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// The geometry description this entry was created from.
38    geometry: BlasTriangleGeometry<'a>,
    /// NOTE(review): presumably `Some(blas)` on the last geometry of a BLAS so
    /// the consumer knows where one BLAS ends — confirm against `iter_blas` /
    /// `iter_buffers`.
39    ending_blas: Option<Arc<Blas>>,
40}
41
/// One BLAS scheduled for building; each entry is turned into a hal build
/// descriptor by `map_blas`.
42struct BlasStore<'a> {
    /// The BLAS being built.
43    blas: Arc<Blas>,
    /// hal-level geometry entries describing the build input.
44    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// NOTE(review): presumably the byte offset of this build's region inside
    /// the shared scratch buffer, mirroring the TLAS path — confirm in
    /// `iter_buffers`.
45    scratch_buffer_offset: u64,
46}
47
/// The hal-facing part of a TLAS scheduled for building.
48struct UnsafeTlasStore<'a> {
    /// The TLAS being built.
49    tlas: Arc<Tlas>,
    /// hal-level entries for the build; the build path in this file fills this
    /// with an `Instances` entry that points at the TLAS's instance buffer.
50    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// Byte offset of this build's region inside the shared scratch buffer.
51    scratch_buffer_offset: u64,
52}
53
/// A TLAS scheduled for building together with the location of its serialized
/// instances in the shared staging byte vector.
54struct TlasStore<'a> {
    /// The TLAS, its hal entries and its scratch-buffer offset.
55    internal: UnsafeTlasStore<'a>,
    /// Byte range inside the instance staging data that holds this TLAS's
    /// instances; an empty range means there is nothing to upload for it.
56    range: Range<usize>,
57}
58
59impl Global {
60    pub fn command_encoder_mark_acceleration_structures_built(
61        &self,
62        command_encoder_id: CommandEncoderId,
63        blas_ids: &[BlasId],
64        tlas_ids: &[TlasId],
65    ) -> Result<(), EncoderStateError> {
66        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
67
68        let hub = &self.hub;
69
70        let cmd_enc = hub.command_encoders.get(command_encoder_id);
71
72        let mut cmd_buf_data = cmd_enc.data.lock();
73        cmd_buf_data.record_with(
74            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
75                let device = &cmd_enc.device;
76                device.check_is_valid()?;
77                device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
78
79                let mut build_command = AsBuild::default();
80
81                for blas in blas_ids {
82                    let blas = hub.blas_s.get(*blas).get()?;
83                    build_command.blas_s_built.push(blas);
84                }
85
86                for tlas in tlas_ids {
87                    let tlas = hub.tlas_s.get(*tlas).get()?;
88                    build_command.tlas_s_built.push(TlasBuild {
89                        tlas,
90                        dependencies: Vec::new(),
91                    });
92                }
93
94                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
95                Ok(())
96            },
97        )
98    }
99
100    pub fn command_encoder_build_acceleration_structures<'a>(
101        &self,
102        command_encoder_id: CommandEncoderId,
103        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
104        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
105    ) -> Result<(), EncoderStateError> {
106        profiling::scope!("CommandEncoder::build_acceleration_structures");
107
108        let hub = &self.hub;
109
110        let cmd_enc = hub.command_encoders.get(command_encoder_id);
111
112        let mut build_command = AsBuild::default();
113
114        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
115            .map(|blas_entry| {
116                let geometries = match blas_entry.geometries {
117                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
118                        TraceBlasGeometries::TriangleGeometries(
119                            triangle_geometries
120                                .map(|tg| TraceBlasTriangleGeometry {
121                                    size: tg.size.clone(),
122                                    vertex_buffer: tg.vertex_buffer,
123                                    index_buffer: tg.index_buffer,
124                                    transform_buffer: tg.transform_buffer,
125                                    first_vertex: tg.first_vertex,
126                                    vertex_stride: tg.vertex_stride,
127                                    first_index: tg.first_index,
128                                    transform_buffer_offset: tg.transform_buffer_offset,
129                                })
130                                .collect(),
131                        )
132                    }
133                };
134                TraceBlasBuildEntry {
135                    blas_id: blas_entry.blas_id,
136                    geometries,
137                }
138            })
139            .collect();
140
141        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
142            .map(|package: TlasPackage| {
143                let instances = package
144                    .instances
145                    .map(|instance| {
146                        instance.map(|instance| TraceTlasInstance {
147                            blas_id: instance.blas_id,
148                            transform: *instance.transform,
149                            custom_data: instance.custom_data,
150                            mask: instance.mask,
151                        })
152                    })
153                    .collect();
154                TraceTlasPackage {
155                    tlas_id: package.tlas_id,
156                    instances,
157                    lowest_unmodified: package.lowest_unmodified,
158                }
159            })
160            .collect();
161
162        let blas_iter = trace_blas.iter().map(|blas_entry| {
163            let geometries = match &blas_entry.geometries {
164                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
165                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
166                        size: &tg.size,
167                        vertex_buffer: tg.vertex_buffer,
168                        index_buffer: tg.index_buffer,
169                        transform_buffer: tg.transform_buffer,
170                        first_vertex: tg.first_vertex,
171                        vertex_stride: tg.vertex_stride,
172                        first_index: tg.first_index,
173                        transform_buffer_offset: tg.transform_buffer_offset,
174                    });
175                    BlasGeometries::TriangleGeometries(Box::new(iter))
176                }
177            };
178            BlasBuildEntry {
179                blas_id: blas_entry.blas_id,
180                geometries,
181            }
182        });
183
184        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
185            let instances = tlas_package.instances.iter().map(|instance| {
186                instance.as_ref().map(|instance| TlasInstance {
187                    blas_id: instance.blas_id,
188                    transform: &instance.transform,
189                    custom_data: instance.custom_data,
190                    mask: instance.mask,
191                })
192            });
193            TlasPackage {
194                tlas_id: tlas_package.tlas_id,
195                instances: Box::new(instances),
196                lowest_unmodified: tlas_package.lowest_unmodified,
197            }
198        });
199
200        let mut cmd_buf_data = cmd_enc.data.lock();
201        cmd_buf_data.record_with(|cmd_buf_data| {
202            #[cfg(feature = "trace")]
203            if let Some(ref mut list) = cmd_buf_data.commands {
204                list.push(crate::device::trace::Command::BuildAccelerationStructures {
205                    blas: trace_blas.clone(),
206                    tlas: trace_tlas.clone(),
207                });
208            }
209
210            let device = &cmd_enc.device;
211            device.check_is_valid()?;
212            device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
213
214            let mut buf_storage = Vec::new();
215            iter_blas(
216                blas_iter,
217                cmd_buf_data,
218                &mut build_command,
219                &mut buf_storage,
220                hub,
221            )?;
222
223            let snatch_guard = device.snatchable_lock.read();
224            let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
225            let mut scratch_buffer_blas_size = 0;
226            let mut blas_storage = Vec::new();
227            iter_buffers(
228                &mut buf_storage,
229                &snatch_guard,
230                &mut input_barriers,
231                cmd_buf_data,
232                &mut scratch_buffer_blas_size,
233                &mut blas_storage,
234                hub,
235                device.alignments.ray_tracing_scratch_buffer_alignment,
236            )?;
237            let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();
238
239            for package in tlas_iter {
240                let tlas = hub.tlas_s.get(package.tlas_id).get()?;
241
242                cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
243
244                tlas_lock_store.push((Some(package), tlas))
245            }
246
247            let mut scratch_buffer_tlas_size = 0;
248            let mut tlas_storage = Vec::<TlasStore>::new();
249            let mut instance_buffer_staging_source = Vec::<u8>::new();
250
251            for (package, tlas) in &mut tlas_lock_store {
252                let package = package.take().unwrap();
253
254                let scratch_buffer_offset = scratch_buffer_tlas_size;
255                scratch_buffer_tlas_size += align_to(
256                    tlas.size_info.build_scratch_size as u32,
257                    device.alignments.ray_tracing_scratch_buffer_alignment,
258                ) as u64;
259
260                let first_byte_index = instance_buffer_staging_source.len();
261
262                let mut dependencies = Vec::new();
263
264                let mut instance_count = 0;
265                for instance in package.instances.flatten() {
266                    if instance.custom_data >= (1u32 << 24u32) {
267                        return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
268                            tlas.error_ident(),
269                        ));
270                    }
271                    let blas = hub.blas_s.get(instance.blas_id).get()?;
272
273                    cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
274
275                    instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
276                        hal::TlasInstance {
277                            transform: *instance.transform,
278                            custom_data: instance.custom_data,
279                            mask: instance.mask,
280                            blas_address: blas.handle,
281                        },
282                    ));
283
284                    if tlas.flags.contains(
285                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
286                    ) && !blas.flags.contains(
287                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
288                    ) {
289                        return Err(
290                            BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
291                                tlas.error_ident(),
292                                blas.error_ident(),
293                            ),
294                        );
295                    }
296
297                    instance_count += 1;
298
299                    dependencies.push(blas.clone());
300                }
301
302                build_command.tlas_s_built.push(TlasBuild {
303                    tlas: tlas.clone(),
304                    dependencies,
305                });
306
307                if instance_count > tlas.max_instance_count {
308                    return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
309                        tlas.error_ident(),
310                        instance_count,
311                        tlas.max_instance_count,
312                    ));
313                }
314
315                tlas_storage.push(TlasStore {
316                    internal: UnsafeTlasStore {
317                        tlas: tlas.clone(),
318                        entries: hal::AccelerationStructureEntries::Instances(
319                            hal::AccelerationStructureInstances {
320                                buffer: Some(tlas.instance_buffer.as_ref()),
321                                offset: 0,
322                                count: instance_count,
323                            },
324                        ),
325                        scratch_buffer_offset,
326                    },
327                    range: first_byte_index..instance_buffer_staging_source.len(),
328                });
329            }
330
331            let Some(scratch_size) =
332                wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
333            else {
334                // if the size is zero there is nothing to build
335                return Ok(());
336            };
337
338            let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;
339
340            let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
341                buffer: scratch_buffer.raw(),
342                usage: hal::StateTransition {
343                    from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
344                    to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
345                },
346            };
347
348            let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());
349
350            for &TlasStore {
351                internal:
352                    UnsafeTlasStore {
353                        ref tlas,
354                        ref entries,
355                        ref scratch_buffer_offset,
356                    },
357                ..
358            } in &tlas_storage
359            {
360                if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
361                    log::info!("only rebuild implemented")
362                }
363                tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
364                    entries,
365                    mode: hal::AccelerationStructureBuildMode::Build,
366                    flags: tlas.flags,
367                    source_acceleration_structure: None,
368                    destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
369                    scratch_buffer: scratch_buffer.raw(),
370                    scratch_buffer_offset: *scratch_buffer_offset,
371                })
372            }
373
374            let blas_present = !blas_storage.is_empty();
375            let tlas_present = !tlas_storage.is_empty();
376
377            let cmd_buf_raw = cmd_buf_data.encoder.open()?;
378
379            let mut blas_s_compactable = Vec::new();
380            let mut descriptors = Vec::new();
381
382            for storage in &blas_storage {
383                descriptors.push(map_blas(
384                    storage,
385                    scratch_buffer.raw(),
386                    &snatch_guard,
387                    &mut blas_s_compactable,
388                )?);
389            }
390
391            build_blas(
392                cmd_buf_raw,
393                blas_present,
394                tlas_present,
395                input_barriers,
396                &descriptors,
397                scratch_buffer_barrier,
398                blas_s_compactable,
399            );
400
401            if tlas_present {
402                let staging_buffer = if !instance_buffer_staging_source.is_empty() {
403                    let mut staging_buffer = StagingBuffer::new(
404                        device,
405                        wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
406                    )?;
407                    staging_buffer.write(&instance_buffer_staging_source);
408                    let flushed = staging_buffer.flush();
409                    Some(flushed)
410                } else {
411                    None
412                };
413
414                unsafe {
415                    if let Some(ref staging_buffer) = staging_buffer {
416                        cmd_buf_raw.transition_buffers(&[
417                            hal::BufferBarrier::<dyn hal::DynBuffer> {
418                                buffer: staging_buffer.raw(),
419                                usage: hal::StateTransition {
420                                    from: BufferUses::MAP_WRITE,
421                                    to: BufferUses::COPY_SRC,
422                                },
423                            },
424                        ]);
425                    }
426                }
427
428                let mut instance_buffer_barriers = Vec::new();
429                for &TlasStore {
430                    internal: UnsafeTlasStore { ref tlas, .. },
431                    ref range,
432                } in &tlas_storage
433                {
434                    let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
435                        None => continue,
436                        Some(size) => size,
437                    };
438                    instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
439                        buffer: tlas.instance_buffer.as_ref(),
440                        usage: hal::StateTransition {
441                            from: BufferUses::COPY_DST,
442                            to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
443                        },
444                    });
445                    unsafe {
446                        cmd_buf_raw.transition_buffers(&[
447                            hal::BufferBarrier::<dyn hal::DynBuffer> {
448                                buffer: tlas.instance_buffer.as_ref(),
449                                usage: hal::StateTransition {
450                                    from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
451                                    to: BufferUses::COPY_DST,
452                                },
453                            },
454                        ]);
455                        let temp = hal::BufferCopy {
456                            src_offset: range.start as u64,
457                            dst_offset: 0,
458                            size,
459                        };
460                        cmd_buf_raw.copy_buffer_to_buffer(
461                            // the range whose size we just checked end is at (at that point in time) instance_buffer_staging_source.len()
462                            // and since instance_buffer_staging_source doesn't shrink we can un wrap this without a panic
463                            staging_buffer.as_ref().unwrap().raw(),
464                            tlas.instance_buffer.as_ref(),
465                            &[temp],
466                        );
467                    }
468                }
469
470                unsafe {
471                    cmd_buf_raw.transition_buffers(&instance_buffer_barriers);
472
473                    cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);
474
475                    cmd_buf_raw.place_acceleration_structure_barrier(
476                        hal::AccelerationStructureBarrier {
477                            usage: hal::StateTransition {
478                                from: hal::AccelerationStructureUses::BUILD_OUTPUT,
479                                to: hal::AccelerationStructureUses::SHADER_INPUT,
480                            },
481                        },
482                    );
483                }
484
485                if let Some(staging_buffer) = staging_buffer {
486                    cmd_buf_data
487                        .temp_resources
488                        .push(TempResource::StagingBuffer(staging_buffer));
489                }
490            }
491
492            cmd_buf_data
493                .temp_resources
494                .push(TempResource::ScratchBuffer(scratch_buffer));
495
496            cmd_buf_data.as_actions.push(AsAction::Build(build_command));
497
498            Ok(())
499        })
500    }
501}
502
503impl CommandBufferMutable {
504    pub(crate) fn validate_acceleration_structure_actions(
505        &self,
506        snatch_guard: &SnatchGuard,
507        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
508    ) -> Result<(), ValidateAsActionsError> {
509        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
510        for action in &self.as_actions {
511            match action {
512                AsAction::Build(build) => {
513                    let build_command_index = NonZeroU64::new(
514                        command_index_guard.next_acceleration_structure_build_command_index,
515                    )
516                    .unwrap();
517
518                    command_index_guard.next_acceleration_structure_build_command_index += 1;
519                    for blas in build.blas_s_built.iter() {
520                        let mut state_lock = blas.compacted_state.lock();
521                        *state_lock = match *state_lock {
522                            BlasCompactState::Compacted => {
523                                unreachable!("Should be validated out in build.")
524                            }
525                            // Reset the compacted state to idle. This means any prepares, before mapping their
526                            // internal buffer, will terminate.
527                            _ => BlasCompactState::Idle,
528                        };
529                        *blas.built_index.write() = Some(build_command_index);
530                    }
531
532                    for tlas_build in build.tlas_s_built.iter() {
533                        for blas in &tlas_build.dependencies {
534                            if blas.built_index.read().is_none() {
535                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
536                                    blas.error_ident(),
537                                    tlas_build.tlas.error_ident(),
538                                ));
539                            }
540                        }
541                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
542                        tlas_build
543                            .tlas
544                            .dependencies
545                            .write()
546                            .clone_from(&tlas_build.dependencies)
547                    }
548                }
549                AsAction::UseTlas(tlas) => {
550                    let tlas_build_index = tlas.built_index.read();
551                    let dependencies = tlas.dependencies.read();
552
553                    if (*tlas_build_index).is_none() {
554                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
555                    }
556                    for blas in dependencies.deref() {
557                        let blas_build_index = *blas.built_index.read();
558                        if blas_build_index.is_none() {
559                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
560                                tlas.error_ident(),
561                                blas.error_ident(),
562                            ));
563                        }
564                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
565                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
566                                blas.error_ident(),
567                                tlas.error_ident(),
568                            ));
569                        }
570                        blas.try_raw(snatch_guard)?;
571                    }
572                }
573            }
574        }
575        Ok(())
576    }
577}
578
579/// Iterates over the BLAS iterator and its geometries, pushing the buffers into a storage vector (and also performing some validation).
580fn iter_blas<'a>(
581    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
582    cmd_buf_data: &mut CommandBufferMutable,
583    build_command: &mut AsBuild,
584    buf_storage: &mut Vec<TriangleBufferStore<'a>>,
585    hub: &Hub,
586) -> Result<(), BuildAccelerationStructureError> {
587    let mut temp_buffer = Vec::new();
588    for entry in blas_iter {
589        let blas = hub.blas_s.get(entry.blas_id).get()?;
590        cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
591
592        build_command.blas_s_built.push(blas.clone());
593
594        match entry.geometries {
595            BlasGeometries::TriangleGeometries(triangle_geometries) => {
596                for (i, mesh) in triangle_geometries.enumerate() {
597                    let size_desc = match &blas.sizes {
598                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
599                    };
600                    if i >= size_desc.len() {
601                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
602                            blas.error_ident(),
603                        ));
604                    }
605                    let size_desc = &size_desc[i];
606
607                    if size_desc.flags != mesh.size.flags {
608                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
609                            blas.error_ident(),
610                            size_desc.flags,
611                            mesh.size.flags,
612                        ));
613                    }
614
615                    if size_desc.vertex_count < mesh.size.vertex_count {
616                        return Err(
617                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
618                                blas.error_ident(),
619                                size_desc.vertex_count,
620                                mesh.size.vertex_count,
621                            ),
622                        );
623                    }
624
625                    if size_desc.vertex_format != mesh.size.vertex_format {
626                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
627                            blas.error_ident(),
628                            size_desc.vertex_format,
629                            mesh.size.vertex_format,
630                        ));
631                    }
632
633                    if size_desc
634                        .vertex_format
635                        .min_acceleration_structure_vertex_stride()
636                        > mesh.vertex_stride
637                    {
638                        return Err(BuildAccelerationStructureError::VertexStrideTooSmall(
639                            blas.error_ident(),
640                            size_desc
641                                .vertex_format
642                                .min_acceleration_structure_vertex_stride(),
643                            mesh.vertex_stride,
644                        ));
645                    }
646
647                    if mesh.vertex_stride
648                        % size_desc
649                            .vertex_format
650                            .acceleration_structure_stride_alignment()
651                        != 0
652                    {
653                        return Err(BuildAccelerationStructureError::VertexStrideUnaligned(
654                            blas.error_ident(),
655                            size_desc
656                                .vertex_format
657                                .acceleration_structure_stride_alignment(),
658                            mesh.vertex_stride,
659                        ));
660                    }
661
662                    match (size_desc.index_count, mesh.size.index_count) {
663                        (Some(_), None) | (None, Some(_)) => {
664                            return Err(
665                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
666                                    blas.error_ident(),
667                                ),
668                            )
669                        }
670                        (Some(create), Some(build)) if create < build => {
671                            return Err(
672                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
673                                    blas.error_ident(),
674                                    create,
675                                    build,
676                                ),
677                            )
678                        }
679                        _ => {}
680                    }
681
682                    if size_desc.index_format != mesh.size.index_format {
683                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
684                            blas.error_ident(),
685                            size_desc.index_format,
686                            mesh.size.index_format,
687                        ));
688                    }
689
690                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
691                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
692                            blas.error_ident(),
693                        ));
694                    }
695                    let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
696                    let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
697                        &vertex_buffer,
698                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
699                    );
700                    let index_data = if let Some(index_id) = mesh.index_buffer {
701                        let index_buffer = hub.buffers.get(index_id).get()?;
702                        if mesh.first_index.is_none()
703                            || mesh.size.index_count.is_none()
704                            || mesh.size.index_count.is_none()
705                        {
706                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
707                                index_buffer.error_ident(),
708                            ));
709                        }
710                        let data = cmd_buf_data.trackers.buffers.set_single(
711                            &index_buffer,
712                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
713                        );
714                        Some((index_buffer, data))
715                    } else {
716                        None
717                    };
718                    let transform_data = if let Some(transform_id) = mesh.transform_buffer {
719                        if !blas
720                            .flags
721                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
722                        {
723                            return Err(BuildAccelerationStructureError::UseTransformMissing(
724                                blas.error_ident(),
725                            ));
726                        }
727                        let transform_buffer = hub.buffers.get(transform_id).get()?;
728                        if mesh.transform_buffer_offset.is_none() {
729                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
730                                transform_buffer.error_ident(),
731                            ));
732                        }
733                        let data = cmd_buf_data.trackers.buffers.set_single(
734                            &transform_buffer,
735                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
736                        );
737                        Some((transform_buffer, data))
738                    } else {
739                        if blas
740                            .flags
741                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
742                        {
743                            return Err(BuildAccelerationStructureError::TransformMissing(
744                                blas.error_ident(),
745                            ));
746                        }
747                        None
748                    };
749                    temp_buffer.push(TriangleBufferStore {
750                        vertex_buffer,
751                        vertex_transition: vertex_pending,
752                        index_buffer_transition: index_data,
753                        transform_buffer_transition: transform_data,
754                        geometry: mesh,
755                        ending_blas: None,
756                    });
757                }
758
759                if let Some(last) = temp_buffer.last_mut() {
760                    last.ending_blas = Some(blas);
761                    buf_storage.append(&mut temp_buffer);
762                }
763            }
764        }
765    }
766    Ok(())
767}
768
769/// Iterates over the buffers generated in [iter_blas], convert the barriers into hal barriers, and the triangles into [hal::AccelerationStructureEntries] (and also some validation).
770fn iter_buffers<'a, 'b>(
771    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
772    snatch_guard: &'a SnatchGuard,
773    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
774    cmd_buf_data: &mut CommandBufferMutable,
775    scratch_buffer_blas_size: &mut u64,
776    blas_storage: &mut Vec<BlasStore<'a>>,
777    hub: &Hub,
778    ray_tracing_scratch_buffer_alignment: u32,
779) -> Result<(), BuildAccelerationStructureError> {
780    let mut triangle_entries =
781        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
782    for buf in buf_storage {
783        let mesh = &buf.geometry;
784        let vertex_buffer = {
785            let vertex_buffer = buf.vertex_buffer.as_ref();
786            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
787            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
788
789            if let Some(barrier) = buf
790                .vertex_transition
791                .take()
792                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
793            {
794                input_barriers.push(barrier);
795            }
796            if vertex_buffer.size
797                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
798            {
799                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
800                    vertex_buffer.error_ident(),
801                    vertex_buffer.size,
802                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
803                ));
804            }
805            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
806            cmd_buf_data.buffer_memory_init_actions.extend(
807                vertex_buffer.initialization_status.read().create_action(
808                    &hub.buffers.get(mesh.vertex_buffer).get()?,
809                    vertex_buffer_offset
810                        ..(vertex_buffer_offset
811                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
812                    MemoryInitKind::NeedsInitializedMemory,
813                ),
814            );
815            vertex_raw
816        };
817        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
818            buf.index_buffer_transition
819        {
820            let index_raw = index_buffer.try_raw(snatch_guard)?;
821            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
822
823            if let Some(barrier) = index_pending
824                .take()
825                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
826            {
827                input_barriers.push(barrier);
828            }
829            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
830            let offset = mesh.first_index.unwrap() as u64 * index_stride;
831            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;
832
833            if mesh.size.index_count.unwrap() % 3 != 0 {
834                return Err(BuildAccelerationStructureError::InvalidIndexCount(
835                    index_buffer.error_ident(),
836                    mesh.size.index_count.unwrap(),
837                ));
838            }
839            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
840                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
841                    index_buffer.error_ident(),
842                    index_buffer.size,
843                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
844                ));
845            }
846
847            cmd_buf_data.buffer_memory_init_actions.extend(
848                index_buffer.initialization_status.read().create_action(
849                    index_buffer,
850                    offset..(offset + index_buffer_size),
851                    MemoryInitKind::NeedsInitializedMemory,
852                ),
853            );
854            Some(index_raw)
855        } else {
856            None
857        };
858        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
859            buf.transform_buffer_transition
860        {
861            if mesh.transform_buffer_offset.is_none() {
862                return Err(BuildAccelerationStructureError::MissingAssociatedData(
863                    transform_buffer.error_ident(),
864                ));
865            }
866            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
867            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
868
869            if let Some(barrier) = transform_pending
870                .take()
871                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
872            {
873                input_barriers.push(barrier);
874            }
875
876            let offset = mesh.transform_buffer_offset.unwrap();
877
878            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
879                return Err(
880                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
881                        transform_buffer.error_ident(),
882                    ),
883                );
884            }
885            if transform_buffer.size < 48 + offset {
886                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
887                    transform_buffer.error_ident(),
888                    transform_buffer.size,
889                    48 + offset,
890                ));
891            }
892            cmd_buf_data.buffer_memory_init_actions.extend(
893                transform_buffer.initialization_status.read().create_action(
894                    transform_buffer,
895                    offset..(offset + 48),
896                    MemoryInitKind::NeedsInitializedMemory,
897                ),
898            );
899            Some(transform_raw)
900        } else {
901            None
902        };
903
904        let triangles = hal::AccelerationStructureTriangles {
905            vertex_buffer: Some(vertex_buffer),
906            vertex_format: mesh.size.vertex_format,
907            first_vertex: mesh.first_vertex,
908            vertex_count: mesh.size.vertex_count,
909            vertex_stride: mesh.vertex_stride,
910            indices: index_buffer.map(|index_buffer| {
911                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
912                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
913                    format: mesh.size.index_format.unwrap(),
914                    buffer: Some(index_buffer),
915                    offset: mesh.first_index.unwrap() * index_stride,
916                    count: mesh.size.index_count.unwrap(),
917                }
918            }),
919            transform: transform_buffer.map(|transform_buffer| {
920                hal::AccelerationStructureTriangleTransform {
921                    buffer: transform_buffer,
922                    offset: mesh.transform_buffer_offset.unwrap() as u32,
923                }
924            }),
925            flags: mesh.size.flags,
926        };
927        triangle_entries.push(triangles);
928        if let Some(blas) = buf.ending_blas.take() {
929            let scratch_buffer_offset = *scratch_buffer_blas_size;
930            *scratch_buffer_blas_size += align_to(
931                blas.size_info.build_scratch_size as u32,
932                ray_tracing_scratch_buffer_alignment,
933            ) as u64;
934
935            blas_storage.push(BlasStore {
936                blas,
937                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
938                scratch_buffer_offset,
939            });
940            triangle_entries = Vec::new();
941        }
942    }
943    Ok(())
944}
945
946fn map_blas<'a>(
947    storage: &'a BlasStore<'_>,
948    scratch_buffer: &'a dyn hal::DynBuffer,
949    snatch_guard: &'a SnatchGuard,
950    blases_compactable: &mut Vec<(
951        &'a dyn hal::DynBuffer,
952        &'a dyn hal::DynAccelerationStructure,
953    )>,
954) -> Result<
955    hal::BuildAccelerationStructureDescriptor<
956        'a,
957        dyn hal::DynBuffer,
958        dyn hal::DynAccelerationStructure,
959    >,
960    BuildAccelerationStructureError,
961> {
962    let BlasStore {
963        blas,
964        entries,
965        scratch_buffer_offset,
966    } = storage;
967    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
968        log::info!("only rebuild implemented")
969    }
970    let raw = blas.try_raw(snatch_guard)?;
971
972    let state_lock = blas.compacted_state.lock();
973    if let BlasCompactState::Compacted = *state_lock {
974        return Err(BuildAccelerationStructureError::CompactedBlas(
975            blas.error_ident(),
976        ));
977    }
978
979    if blas
980        .flags
981        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
982    {
983        blases_compactable.push((blas.compaction_buffer.as_ref().unwrap().as_ref(), raw));
984    }
985    Ok(hal::BuildAccelerationStructureDescriptor {
986        entries,
987        mode: hal::AccelerationStructureBuildMode::Build,
988        flags: blas.flags,
989        source_acceleration_structure: None,
990        destination_acceleration_structure: raw,
991        scratch_buffer,
992        scratch_buffer_offset: *scratch_buffer_offset,
993    })
994}
995
996fn build_blas<'a>(
997    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
998    blas_present: bool,
999    tlas_present: bool,
1000    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1001    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1002        'a,
1003        dyn hal::DynBuffer,
1004        dyn hal::DynAccelerationStructure,
1005    >],
1006    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1007    blas_s_for_compaction: Vec<(
1008        &'a dyn hal::DynBuffer,
1009        &'a dyn hal::DynAccelerationStructure,
1010    )>,
1011) {
1012    unsafe {
1013        cmd_buf_raw.transition_buffers(&input_barriers);
1014    }
1015
1016    if blas_present {
1017        unsafe {
1018            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1019                usage: hal::StateTransition {
1020                    from: hal::AccelerationStructureUses::BUILD_INPUT,
1021                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1022                },
1023            });
1024
1025            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1026        }
1027    }
1028
1029    if blas_present && tlas_present {
1030        unsafe {
1031            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1032        }
1033    }
1034
1035    let mut source_usage = hal::AccelerationStructureUses::empty();
1036    let mut destination_usage = hal::AccelerationStructureUses::empty();
1037    for &(buf, blas) in blas_s_for_compaction.iter() {
1038        unsafe {
1039            cmd_buf_raw.transition_buffers(&[hal::BufferBarrier {
1040                buffer: buf,
1041                usage: hal::StateTransition {
1042                    from: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1043                    to: BufferUses::ACCELERATION_STRUCTURE_QUERY,
1044                },
1045            }])
1046        }
1047        unsafe { cmd_buf_raw.read_acceleration_structure_compact_size(blas, buf) }
1048        destination_usage |= hal::AccelerationStructureUses::COPY_SRC;
1049    }
1050
1051    if blas_present {
1052        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1053        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1054    }
1055    if tlas_present {
1056        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1057        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1058    }
1059    unsafe {
1060        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1061            usage: hal::StateTransition {
1062                from: source_usage,
1063                to: destination_usage,
1064            },
1065        });
1066    }
1067}