wgpu_core/device/
ray_tracing.rs

1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace::{Action, IntoTrace};
6use crate::device::DeviceError;
7use crate::{
8    api_log,
9    device::Device,
10    global::Global,
11    hal_label,
12    id::{self, BlasId, TlasId},
13    lock::RwLock,
14    lock::{rank, Mutex},
15    ray_tracing::BlasPrepareCompactError,
16    ray_tracing::{CreateBlasError, CreateTlasError},
17    resource,
18    resource::{
19        BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20    },
21    snatch::Snatchable,
22    LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::{Features, AABB_GEOMETRY_MIN_STRIDE};
26
27impl Device {
28    pub fn create_blas(
29        self: &Arc<Self>,
30        blas_desc: &resource::BlasDescriptor,
31        sizes: wgt::BlasGeometrySizeDescriptors,
32    ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33        self.check_is_valid()?;
34        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36        if blas_desc
37            .flags
38            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39        {
40            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41        }
42
43        let size_info = match &sizes {
44            wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45                if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46                    return Err(CreateBlasError::TooManyGeometries(
47                        self.limits.max_blas_geometry_count,
48                        descriptors.len() as u32,
49                    ));
50                }
51
52                let mut entries =
53                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54                        descriptors.len(),
55                    );
56                for desc in descriptors {
57                    if desc.index_count.is_some() != desc.index_format.is_some() {
58                        return Err(CreateBlasError::MissingIndexData);
59                    }
60                    let indices =
61                        desc.index_count
62                            .map(|count| AccelerationStructureTriangleIndices::<
63                                dyn hal::DynBuffer,
64                            > {
65                                format: desc.index_format.unwrap(),
66                                buffer: Some(self.zero_buffer.as_ref()),
67                                offset: 0,
68                                count,
69                            });
70                    if !self
71                        .features
72                        .allowed_vertex_formats_for_blas()
73                        .contains(&desc.vertex_format)
74                    {
75                        return Err(CreateBlasError::InvalidVertexFormat(
76                            desc.vertex_format,
77                            self.features.allowed_vertex_formats_for_blas(),
78                        ));
79                    }
80
81                    let mut transform = None;
82
83                    if blas_desc
84                        .flags
85                        .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86                    {
87                        transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88                            buffer: self.zero_buffer.as_ref(),
89                            offset: 0,
90                        })
91                    }
92
93                    if desc.vertex_count > self.limits.max_blas_primitive_count {
94                        return Err(CreateBlasError::TooManyPrimitives(
95                            self.limits.max_blas_primitive_count,
96                            desc.vertex_count,
97                        ));
98                    }
99
100                    entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101                        vertex_buffer: Some(self.zero_buffer.as_ref()),
102                        vertex_format: desc.vertex_format,
103                        first_vertex: 0,
104                        vertex_count: desc.vertex_count,
105                        vertex_stride: 0,
106                        indices,
107                        transform,
108                        flags: desc.flags,
109                    });
110                }
111                unsafe {
112                    self.raw().get_acceleration_structure_build_sizes(
113                        &hal::GetAccelerationStructureBuildSizesDescriptor {
114                            entries: &hal::AccelerationStructureEntries::Triangles(entries),
115                            flags: blas_desc.flags,
116                        },
117                    )
118                }
119            }
120            wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => {
121                if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
122                    return Err(CreateBlasError::TooManyGeometries(
123                        self.limits.max_blas_geometry_count,
124                        descriptors.len() as u32,
125                    ));
126                }
127
128                let mut entries =
129                    Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::with_capacity(
130                        descriptors.len(),
131                    );
132                for desc in descriptors {
133                    if desc.primitive_count > self.limits.max_blas_primitive_count {
134                        return Err(CreateBlasError::TooManyPrimitives(
135                            self.limits.max_blas_primitive_count,
136                            desc.primitive_count,
137                        ));
138                    }
139
140                    entries.push(hal::AccelerationStructureAABBs::<dyn hal::DynBuffer> {
141                        buffer: Some(self.zero_buffer.as_ref()),
142                        offset: 0,
143                        count: desc.primitive_count,
144                        stride: AABB_GEOMETRY_MIN_STRIDE,
145                        flags: desc.flags,
146                    });
147                }
148                unsafe {
149                    self.raw().get_acceleration_structure_build_sizes(
150                        &hal::GetAccelerationStructureBuildSizesDescriptor {
151                            entries: &hal::AccelerationStructureEntries::AABBs(entries),
152                            flags: blas_desc.flags,
153                        },
154                    )
155                }
156            }
157        };
158
159        let raw = unsafe {
160            self.raw()
161                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
162                    label: blas_desc.label.as_deref(),
163                    size: size_info.acceleration_structure_size,
164                    format: hal::AccelerationStructureFormat::BottomLevel,
165                    allow_compaction: blas_desc
166                        .flags
167                        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
168                })
169        }
170        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
171
172        let compaction_buffer = if blas_desc
173            .flags
174            .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
175        {
176            Some(ManuallyDrop::new(unsafe {
177                self.raw()
178                    .create_buffer(&hal::BufferDescriptor {
179                        label: Some("(wgpu internal) compaction read-back buffer"),
180                        size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
181                        usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
182                            | wgpu_types::BufferUses::MAP_READ,
183                        memory_flags: hal::MemoryFlags::PREFER_COHERENT,
184                    })
185                    .map_err(DeviceError::from_hal)?
186            }))
187        } else {
188            None
189        };
190
191        let handle = unsafe {
192            self.raw()
193                .get_acceleration_structure_device_address(raw.as_ref())
194        };
195
196        Ok(Arc::new(resource::Blas {
197            raw: Snatchable::new(raw),
198            device: self.clone(),
199            size_info,
200            sizes,
201            flags: blas_desc.flags,
202            update_mode: blas_desc.update_mode,
203            handle,
204            label: blas_desc.label.to_string(),
205            built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
206            tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
207            compaction_buffer,
208            compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
209        }))
210    }
211
212    pub fn create_tlas(
213        self: &Arc<Self>,
214        desc: &resource::TlasDescriptor,
215    ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
216        self.check_is_valid()?;
217        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
218
219        if desc.max_instances > self.limits.max_tlas_instance_count {
220            return Err(CreateTlasError::TooManyInstances(
221                self.limits.max_tlas_instance_count,
222                desc.max_instances,
223            ));
224        }
225
226        if desc
227            .flags
228            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
229        {
230            return Err(CreateTlasError::DisallowedFlag(
231                wgt::AccelerationStructureFlags::USE_TRANSFORM,
232            ));
233        }
234
235        if desc
236            .flags
237            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
238        {
239            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
240        }
241
242        let size_info = unsafe {
243            self.raw().get_acceleration_structure_build_sizes(
244                &hal::GetAccelerationStructureBuildSizesDescriptor {
245                    entries: &hal::AccelerationStructureEntries::Instances(
246                        hal::AccelerationStructureInstances {
247                            buffer: Some(self.zero_buffer.as_ref()),
248                            offset: 0,
249                            count: desc.max_instances,
250                        },
251                    ),
252                    flags: desc.flags,
253                },
254            )
255        };
256
257        let raw = unsafe {
258            self.raw()
259                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
260                    label: desc.label.as_deref(),
261                    size: size_info.acceleration_structure_size,
262                    format: hal::AccelerationStructureFormat::TopLevel,
263                    allow_compaction: false,
264                })
265        }
266        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
267
268        let instance_buffer_size = self
269            .alignments
270            .raw_tlas_instance_size
271            .saturating_mul(desc.max_instances.max(1));
272        let instance_buffer = unsafe {
273            self.raw().create_buffer(&hal::BufferDescriptor {
274                label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
275                size: u64::from(instance_buffer_size),
276                usage: wgt::BufferUses::COPY_DST
277                    | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
278                memory_flags: hal::MemoryFlags::PREFER_COHERENT,
279            })
280        }
281        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
282
283        Ok(Arc::new(resource::Tlas {
284            raw: Snatchable::new(raw),
285            device: self.clone(),
286            size_info,
287            flags: desc.flags,
288            update_mode: desc.update_mode,
289            built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
290            dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
291            instance_buffer: ManuallyDrop::new(instance_buffer),
292            label: desc.label.to_string(),
293            max_instance_count: desc.max_instances,
294            tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
295        }))
296    }
297}
298
299impl Global {
300    pub fn device_create_blas(
301        &self,
302        device_id: id::DeviceId,
303        desc: &resource::BlasDescriptor,
304        sizes: wgt::BlasGeometrySizeDescriptors,
305        id_in: Option<BlasId>,
306    ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
307        profiling::scope!("Device::create_blas");
308
309        let fid = self.hub.blas_s.prepare(id_in);
310
311        let error = 'error: {
312            let device = self.hub.devices.get(device_id);
313
314            #[cfg(feature = "trace")]
315            let trace_sizes = sizes.clone();
316
317            let blas = match device.create_blas(desc, sizes) {
318                Ok(blas) => blas,
319                Err(e) => break 'error e,
320            };
321            let handle = blas.handle;
322
323            #[cfg(feature = "trace")]
324            if let Some(trace) = device.trace.lock().as_mut() {
325                trace.add(Action::CreateBlas {
326                    id: blas.to_trace(),
327                    desc: desc.clone(),
328                    sizes: trace_sizes,
329                });
330            }
331
332            let id = fid.assign(Fallible::Valid(blas));
333            api_log!("Device::create_blas -> {id:?}");
334
335            return (id, Some(handle), None);
336        };
337
338        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
339        (id, None, Some(error))
340    }
341
342    pub fn device_create_tlas(
343        &self,
344        device_id: id::DeviceId,
345        desc: &resource::TlasDescriptor,
346        id_in: Option<TlasId>,
347    ) -> (TlasId, Option<CreateTlasError>) {
348        profiling::scope!("Device::create_tlas");
349
350        let fid = self.hub.tlas_s.prepare(id_in);
351
352        let error = 'error: {
353            let device = self.hub.devices.get(device_id);
354
355            let tlas = match device.create_tlas(desc) {
356                Ok(tlas) => tlas,
357                Err(e) => break 'error e,
358            };
359
360            #[cfg(feature = "trace")]
361            if let Some(trace) = device.trace.lock().as_mut() {
362                trace.add(Action::CreateTlas {
363                    id: tlas.to_trace(),
364                    desc: desc.clone(),
365                });
366            }
367
368            let id = fid.assign(Fallible::Valid(tlas));
369            api_log!("Device::create_tlas -> {id:?}");
370
371            return (id, None);
372        };
373
374        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
375        (id, Some(error))
376    }
377
378    pub fn blas_drop(&self, blas_id: BlasId) {
379        profiling::scope!("Blas::drop");
380        api_log!("Blas::drop {blas_id:?}");
381
382        let _blas = self.hub.blas_s.remove(blas_id);
383
384        #[cfg(feature = "trace")]
385        if let Ok(blas) = _blas.get() {
386            if let Some(t) = blas.device.trace.lock().as_mut() {
387                t.add(Action::DestroyBlas(blas.to_trace()));
388            }
389        }
390    }
391
392    pub fn tlas_drop(&self, tlas_id: TlasId) {
393        profiling::scope!("Tlas::drop");
394        api_log!("Tlas::drop {tlas_id:?}");
395
396        let _tlas = self.hub.tlas_s.remove(tlas_id);
397
398        #[cfg(feature = "trace")]
399        if let Ok(tlas) = _tlas.get() {
400            if let Some(t) = tlas.device.trace.lock().as_mut() {
401                t.add(Action::DestroyTlas(tlas.to_trace()));
402            }
403        }
404    }
405
406    pub fn blas_prepare_compact_async(
407        &self,
408        blas_id: BlasId,
409        callback: Option<BlasCompactCallback>,
410    ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
411        profiling::scope!("Blas::prepare_compact_async");
412        api_log!("Blas::prepare_compact_async {blas_id:?}");
413
414        let hub = &self.hub;
415
416        let compact_result = match hub.blas_s.get(blas_id).get() {
417            Ok(blas) => blas.prepare_compact_async(callback),
418            Err(e) => Err((callback, e.into())),
419        };
420
421        match compact_result {
422            Ok(submission_index) => Ok(submission_index),
423            Err((mut callback, err)) => {
424                if let Some(callback) = callback.take() {
425                    callback(Err(err.clone()));
426                }
427                Err(err)
428            }
429        }
430    }
431
432    pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
433        profiling::scope!("Blas::prepare_compact_async");
434        api_log!("Blas::prepare_compact_async {blas_id:?}");
435
436        let hub = &self.hub;
437
438        let blas = hub.blas_s.get(blas_id).get()?;
439
440        let lock = blas.compacted_state.lock();
441
442        Ok(matches!(*lock, BlasCompactState::Ready { .. }))
443    }
444}