wgpu_core/device/
ray_tracing.rs

1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace::{Action, IntoTrace};
6use crate::device::DeviceError;
7use crate::{
8    api_log,
9    device::Device,
10    global::Global,
11    hal_label,
12    id::{self, BlasId, TlasId},
13    lock::RwLock,
14    lock::{rank, Mutex},
15    ray_tracing::BlasPrepareCompactError,
16    ray_tracing::{CreateBlasError, CreateTlasError},
17    resource,
18    resource::{
19        BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20    },
21    snatch::Snatchable,
22    LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::{Features, AABB_GEOMETRY_MIN_STRIDE};
26
27impl Device {
28    pub fn create_blas(
29        self: &Arc<Self>,
30        blas_desc: &resource::BlasDescriptor,
31        sizes: wgt::BlasGeometrySizeDescriptors,
32    ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33        self.check_is_valid()?;
34        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36        if blas_desc
37            .flags
38            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39        {
40            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41        }
42
43        let size_info = match &sizes {
44            wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45                if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46                    return Err(CreateBlasError::TooManyGeometries(
47                        self.limits.max_blas_geometry_count,
48                        descriptors.len() as u32,
49                    ));
50                }
51
52                let mut entries =
53                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54                        descriptors.len(),
55                    );
56                for desc in descriptors {
57                    if desc.index_count.is_some() != desc.index_format.is_some() {
58                        return Err(CreateBlasError::MissingIndexData);
59                    }
60                    let indices =
61                        desc.index_count
62                            .map(|count| AccelerationStructureTriangleIndices::<
63                                dyn hal::DynBuffer,
64                            > {
65                                format: desc.index_format.unwrap(),
66                                buffer: Some(self.zero_buffer.as_ref()),
67                                offset: 0,
68                                count,
69                            });
70                    if !self
71                        .features
72                        .allowed_vertex_formats_for_blas()
73                        .contains(&desc.vertex_format)
74                    {
75                        return Err(CreateBlasError::InvalidVertexFormat(
76                            desc.vertex_format,
77                            self.features.allowed_vertex_formats_for_blas(),
78                        ));
79                    }
80
81                    let mut transform = None;
82
83                    if blas_desc
84                        .flags
85                        .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86                    {
87                        transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88                            buffer: self.zero_buffer.as_ref(),
89                            offset: 0,
90                        })
91                    }
92
93                    if desc.vertex_count > self.limits.max_blas_primitive_count {
94                        return Err(CreateBlasError::TooManyPrimitives(
95                            self.limits.max_blas_primitive_count,
96                            desc.vertex_count,
97                        ));
98                    }
99
100                    entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101                        vertex_buffer: Some(self.zero_buffer.as_ref()),
102                        vertex_format: desc.vertex_format,
103                        first_vertex: 0,
104                        vertex_count: desc.vertex_count,
105                        vertex_stride: 0,
106                        indices,
107                        transform,
108                        flags: desc.flags,
109                    });
110                }
111                unsafe {
112                    self.raw().get_acceleration_structure_build_sizes(
113                        &hal::GetAccelerationStructureBuildSizesDescriptor {
114                            entries: &hal::AccelerationStructureEntries::Triangles(entries),
115                            flags: blas_desc.flags,
116                        },
117                    )
118                }
119            }
120            wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => {
121                if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
122                    return Err(CreateBlasError::TooManyGeometries(
123                        self.limits.max_blas_geometry_count,
124                        descriptors.len() as u32,
125                    ));
126                }
127
128                let mut entries =
129                    Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::with_capacity(
130                        descriptors.len(),
131                    );
132                for desc in descriptors {
133                    if desc.primitive_count > self.limits.max_blas_primitive_count {
134                        return Err(CreateBlasError::TooManyPrimitives(
135                            self.limits.max_blas_primitive_count,
136                            desc.primitive_count,
137                        ));
138                    }
139
140                    entries.push(hal::AccelerationStructureAABBs::<dyn hal::DynBuffer> {
141                        buffer: Some(self.zero_buffer.as_ref()),
142                        offset: 0,
143                        count: desc.primitive_count,
144                        stride: AABB_GEOMETRY_MIN_STRIDE,
145                        flags: desc.flags,
146                    });
147                }
148                unsafe {
149                    self.raw().get_acceleration_structure_build_sizes(
150                        &hal::GetAccelerationStructureBuildSizesDescriptor {
151                            entries: &hal::AccelerationStructureEntries::AABBs(entries),
152                            flags: blas_desc.flags,
153                        },
154                    )
155                }
156            }
157        };
158
159        let raw = unsafe {
160            self.raw()
161                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
162                    label: blas_desc.label.as_deref(),
163                    size: size_info.acceleration_structure_size,
164                    format: hal::AccelerationStructureFormat::BottomLevel,
165                    allow_compaction: blas_desc
166                        .flags
167                        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
168                })
169        }
170        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
171
172        let compaction_buffer = if blas_desc
173            .flags
174            .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
175        {
176            Some(ManuallyDrop::new(unsafe {
177                self.raw()
178                    .create_buffer(&hal::BufferDescriptor {
179                        label: Some("(wgpu internal) compaction read-back buffer"),
180                        size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
181                        usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
182                            | wgpu_types::BufferUses::MAP_READ,
183                        memory_flags: hal::MemoryFlags::PREFER_COHERENT,
184                    })
185                    .map_err(DeviceError::from_hal)?
186            }))
187        } else {
188            None
189        };
190
191        let handle = unsafe {
192            self.raw()
193                .get_acceleration_structure_device_address(raw.as_ref())
194        };
195
196        Ok(Arc::new(resource::Blas {
197            raw: Snatchable::new(raw),
198            device: self.clone(),
199            size_info,
200            sizes,
201            flags: blas_desc.flags,
202            update_mode: blas_desc.update_mode,
203            handle,
204            label: blas_desc.label.to_string(),
205            built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
206            tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
207            compaction_buffer,
208            compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
209        }))
210    }
211
212    pub fn create_tlas(
213        self: &Arc<Self>,
214        desc: &resource::TlasDescriptor,
215    ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
216        self.check_is_valid()?;
217        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
218
219        if desc.max_instances > self.limits.max_tlas_instance_count {
220            return Err(CreateTlasError::TooManyInstances(
221                self.limits.max_tlas_instance_count,
222                desc.max_instances,
223            ));
224        }
225
226        if desc
227            .flags
228            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
229        {
230            return Err(CreateTlasError::DisallowedFlag(
231                wgt::AccelerationStructureFlags::USE_TRANSFORM,
232            ));
233        }
234
235        if desc
236            .flags
237            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
238        {
239            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
240        }
241
242        let size_info = unsafe {
243            self.raw().get_acceleration_structure_build_sizes(
244                &hal::GetAccelerationStructureBuildSizesDescriptor {
245                    entries: &hal::AccelerationStructureEntries::Instances(
246                        hal::AccelerationStructureInstances {
247                            buffer: Some(self.zero_buffer.as_ref()),
248                            offset: 0,
249                            count: desc.max_instances,
250                        },
251                    ),
252                    flags: desc.flags,
253                },
254            )
255        };
256
257        let raw = unsafe {
258            self.raw()
259                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
260                    label: desc.label.as_deref(),
261                    size: size_info.acceleration_structure_size,
262                    format: hal::AccelerationStructureFormat::TopLevel,
263                    allow_compaction: false,
264                })
265        }
266        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
267
268        let instance_buffer_size = self
269            .alignments
270            .raw_tlas_instance_size
271            .checked_mul(desc.max_instances.max(1))
272            .expect("max_tlas_instance_count should not allow excessive buffer size");
273        let instance_buffer = unsafe {
274            self.raw().create_buffer(&hal::BufferDescriptor {
275                label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
276                size: u64::from(instance_buffer_size),
277                usage: wgt::BufferUses::COPY_DST
278                    | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
279                memory_flags: hal::MemoryFlags::PREFER_COHERENT,
280            })
281        }
282        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
283
284        Ok(Arc::new(resource::Tlas {
285            raw: Snatchable::new(raw),
286            device: self.clone(),
287            size_info,
288            flags: desc.flags,
289            update_mode: desc.update_mode,
290            built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
291            dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
292            instance_buffer: ManuallyDrop::new(instance_buffer),
293            label: desc.label.to_string(),
294            max_instance_count: desc.max_instances,
295            tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
296        }))
297    }
298}
299
300impl Global {
301    pub fn device_create_blas(
302        &self,
303        device_id: id::DeviceId,
304        desc: &resource::BlasDescriptor,
305        sizes: wgt::BlasGeometrySizeDescriptors,
306        id_in: Option<BlasId>,
307    ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
308        profiling::scope!("Device::create_blas");
309
310        let fid = self.hub.blas_s.prepare(id_in);
311
312        let error = 'error: {
313            let device = self.hub.devices.get(device_id);
314
315            #[cfg(feature = "trace")]
316            let trace_sizes = sizes.clone();
317
318            let blas = match device.create_blas(desc, sizes) {
319                Ok(blas) => blas,
320                Err(e) => break 'error e,
321            };
322            let handle = blas.handle;
323
324            #[cfg(feature = "trace")]
325            if let Some(trace) = device.trace.lock().as_mut() {
326                trace.add(Action::CreateBlas {
327                    id: blas.to_trace(),
328                    desc: desc.clone(),
329                    sizes: trace_sizes,
330                });
331            }
332
333            let id = fid.assign(Fallible::Valid(blas));
334            api_log!("Device::create_blas -> {id:?}");
335
336            return (id, Some(handle), None);
337        };
338
339        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
340        (id, None, Some(error))
341    }
342
343    pub fn device_create_tlas(
344        &self,
345        device_id: id::DeviceId,
346        desc: &resource::TlasDescriptor,
347        id_in: Option<TlasId>,
348    ) -> (TlasId, Option<CreateTlasError>) {
349        profiling::scope!("Device::create_tlas");
350
351        let fid = self.hub.tlas_s.prepare(id_in);
352
353        let error = 'error: {
354            let device = self.hub.devices.get(device_id);
355
356            let tlas = match device.create_tlas(desc) {
357                Ok(tlas) => tlas,
358                Err(e) => break 'error e,
359            };
360
361            #[cfg(feature = "trace")]
362            if let Some(trace) = device.trace.lock().as_mut() {
363                trace.add(Action::CreateTlas {
364                    id: tlas.to_trace(),
365                    desc: desc.clone(),
366                });
367            }
368
369            let id = fid.assign(Fallible::Valid(tlas));
370            api_log!("Device::create_tlas -> {id:?}");
371
372            return (id, None);
373        };
374
375        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
376        (id, Some(error))
377    }
378
379    pub fn blas_drop(&self, blas_id: BlasId) {
380        profiling::scope!("Blas::drop");
381        api_log!("Blas::drop {blas_id:?}");
382
383        let _blas = self.hub.blas_s.remove(blas_id);
384
385        #[cfg(feature = "trace")]
386        if let Ok(blas) = _blas.get() {
387            if let Some(t) = blas.device.trace.lock().as_mut() {
388                t.add(Action::DestroyBlas(blas.to_trace()));
389            }
390        }
391    }
392
393    pub fn tlas_drop(&self, tlas_id: TlasId) {
394        profiling::scope!("Tlas::drop");
395        api_log!("Tlas::drop {tlas_id:?}");
396
397        let _tlas = self.hub.tlas_s.remove(tlas_id);
398
399        #[cfg(feature = "trace")]
400        if let Ok(tlas) = _tlas.get() {
401            if let Some(t) = tlas.device.trace.lock().as_mut() {
402                t.add(Action::DestroyTlas(tlas.to_trace()));
403            }
404        }
405    }
406
407    pub fn blas_prepare_compact_async(
408        &self,
409        blas_id: BlasId,
410        callback: Option<BlasCompactCallback>,
411    ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
412        profiling::scope!("Blas::prepare_compact_async");
413        api_log!("Blas::prepare_compact_async {blas_id:?}");
414
415        let hub = &self.hub;
416
417        let compact_result = match hub.blas_s.get(blas_id).get() {
418            Ok(blas) => blas.prepare_compact_async(callback),
419            Err(e) => Err((callback, e.into())),
420        };
421
422        match compact_result {
423            Ok(submission_index) => Ok(submission_index),
424            Err((mut callback, err)) => {
425                if let Some(callback) = callback.take() {
426                    callback(Err(err.clone()));
427                }
428                Err(err)
429            }
430        }
431    }
432
433    pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
434        profiling::scope!("Blas::prepare_compact_async");
435        api_log!("Blas::prepare_compact_async {blas_id:?}");
436
437        let hub = &self.hub;
438
439        let blas = hub.blas_s.get(blas_id).get()?;
440
441        let lock = blas.compacted_state.lock();
442
443        Ok(matches!(*lock, BlasCompactState::Ready { .. }))
444    }
445}