wgpu_core/device/
ray_tracing.rs

1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace::{Action, IntoTrace};
6use crate::device::DeviceError;
7use crate::{
8    api_log,
9    device::Device,
10    global::Global,
11    hal_label,
12    id::{self, BlasId, TlasId},
13    lock::RwLock,
14    lock::{rank, Mutex},
15    ray_tracing::BlasPrepareCompactError,
16    ray_tracing::{CreateBlasError, CreateTlasError},
17    resource,
18    resource::{
19        BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20    },
21    snatch::Snatchable,
22    LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::Features;
26
27impl Device {
28    pub fn create_blas(
29        self: &Arc<Self>,
30        blas_desc: &resource::BlasDescriptor,
31        sizes: wgt::BlasGeometrySizeDescriptors,
32    ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33        self.check_is_valid()?;
34        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36        if blas_desc
37            .flags
38            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39        {
40            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41        }
42
43        let size_info = match &sizes {
44            wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45                if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46                    return Err(CreateBlasError::TooManyGeometries(
47                        self.limits.max_blas_geometry_count,
48                        descriptors.len() as u32,
49                    ));
50                }
51
52                let mut entries =
53                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54                        descriptors.len(),
55                    );
56                for desc in descriptors {
57                    if desc.index_count.is_some() != desc.index_format.is_some() {
58                        return Err(CreateBlasError::MissingIndexData);
59                    }
60                    let indices =
61                        desc.index_count
62                            .map(|count| AccelerationStructureTriangleIndices::<
63                                dyn hal::DynBuffer,
64                            > {
65                                format: desc.index_format.unwrap(),
66                                buffer: None,
67                                offset: 0,
68                                count,
69                            });
70                    if !self
71                        .features
72                        .allowed_vertex_formats_for_blas()
73                        .contains(&desc.vertex_format)
74                    {
75                        return Err(CreateBlasError::InvalidVertexFormat(
76                            desc.vertex_format,
77                            self.features.allowed_vertex_formats_for_blas(),
78                        ));
79                    }
80
81                    let mut transform = None;
82
83                    if blas_desc
84                        .flags
85                        .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86                    {
87                        transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88                            buffer: self.zero_buffer.as_ref(),
89                            offset: 0,
90                        })
91                    }
92
93                    if desc.vertex_count > self.limits.max_blas_primitive_count {
94                        return Err(CreateBlasError::TooManyPrimitives(
95                            self.limits.max_blas_primitive_count,
96                            desc.vertex_count,
97                        ));
98                    }
99
100                    entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101                        vertex_buffer: None,
102                        vertex_format: desc.vertex_format,
103                        first_vertex: 0,
104                        vertex_count: desc.vertex_count,
105                        vertex_stride: 0,
106                        indices,
107                        transform,
108                        flags: desc.flags,
109                    });
110                }
111                unsafe {
112                    self.raw().get_acceleration_structure_build_sizes(
113                        &hal::GetAccelerationStructureBuildSizesDescriptor {
114                            entries: &hal::AccelerationStructureEntries::Triangles(entries),
115                            flags: blas_desc.flags,
116                        },
117                    )
118                }
119            }
120        };
121
122        let raw = unsafe {
123            self.raw()
124                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
125                    label: blas_desc.label.as_deref(),
126                    size: size_info.acceleration_structure_size,
127                    format: hal::AccelerationStructureFormat::BottomLevel,
128                    allow_compaction: blas_desc
129                        .flags
130                        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
131                })
132        }
133        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
134
135        let compaction_buffer = if blas_desc
136            .flags
137            .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
138        {
139            Some(ManuallyDrop::new(unsafe {
140                self.raw()
141                    .create_buffer(&hal::BufferDescriptor {
142                        label: Some("(wgpu internal) compaction read-back buffer"),
143                        size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
144                        usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
145                            | wgpu_types::BufferUses::MAP_READ,
146                        memory_flags: hal::MemoryFlags::PREFER_COHERENT,
147                    })
148                    .map_err(DeviceError::from_hal)?
149            }))
150        } else {
151            None
152        };
153
154        let handle = unsafe {
155            self.raw()
156                .get_acceleration_structure_device_address(raw.as_ref())
157        };
158
159        Ok(Arc::new(resource::Blas {
160            raw: Snatchable::new(raw),
161            device: self.clone(),
162            size_info,
163            sizes,
164            flags: blas_desc.flags,
165            update_mode: blas_desc.update_mode,
166            handle,
167            label: blas_desc.label.to_string(),
168            built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
169            tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
170            compaction_buffer,
171            compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
172        }))
173    }
174
175    pub fn create_tlas(
176        self: &Arc<Self>,
177        desc: &resource::TlasDescriptor,
178    ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
179        self.check_is_valid()?;
180        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
181
182        if desc.max_instances > self.limits.max_tlas_instance_count {
183            return Err(CreateTlasError::TooManyInstances(
184                self.limits.max_tlas_instance_count,
185                desc.max_instances,
186            ));
187        }
188
189        if desc
190            .flags
191            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
192        {
193            return Err(CreateTlasError::DisallowedFlag(
194                wgt::AccelerationStructureFlags::USE_TRANSFORM,
195            ));
196        }
197
198        if desc
199            .flags
200            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
201        {
202            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
203        }
204
205        let size_info = unsafe {
206            self.raw().get_acceleration_structure_build_sizes(
207                &hal::GetAccelerationStructureBuildSizesDescriptor {
208                    entries: &hal::AccelerationStructureEntries::Instances(
209                        hal::AccelerationStructureInstances {
210                            buffer: None,
211                            offset: 0,
212                            count: desc.max_instances,
213                        },
214                    ),
215                    flags: desc.flags,
216                },
217            )
218        };
219
220        let raw = unsafe {
221            self.raw()
222                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
223                    label: desc.label.as_deref(),
224                    size: size_info.acceleration_structure_size,
225                    format: hal::AccelerationStructureFormat::TopLevel,
226                    allow_compaction: false,
227                })
228        }
229        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
230
231        let instance_buffer_size =
232            self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
233        let instance_buffer = unsafe {
234            self.raw().create_buffer(&hal::BufferDescriptor {
235                label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
236                size: instance_buffer_size as u64,
237                usage: wgt::BufferUses::COPY_DST
238                    | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
239                memory_flags: hal::MemoryFlags::PREFER_COHERENT,
240            })
241        }
242        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
243
244        Ok(Arc::new(resource::Tlas {
245            raw: Snatchable::new(raw),
246            device: self.clone(),
247            size_info,
248            flags: desc.flags,
249            update_mode: desc.update_mode,
250            built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
251            dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
252            instance_buffer: ManuallyDrop::new(instance_buffer),
253            label: desc.label.to_string(),
254            max_instance_count: desc.max_instances,
255            tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
256        }))
257    }
258}
259
260impl Global {
261    pub fn device_create_blas(
262        &self,
263        device_id: id::DeviceId,
264        desc: &resource::BlasDescriptor,
265        sizes: wgt::BlasGeometrySizeDescriptors,
266        id_in: Option<BlasId>,
267    ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
268        profiling::scope!("Device::create_blas");
269
270        let fid = self.hub.blas_s.prepare(id_in);
271
272        let error = 'error: {
273            let device = self.hub.devices.get(device_id);
274
275            #[cfg(feature = "trace")]
276            let trace_sizes = sizes.clone();
277
278            let blas = match device.create_blas(desc, sizes) {
279                Ok(blas) => blas,
280                Err(e) => break 'error e,
281            };
282            let handle = blas.handle;
283
284            #[cfg(feature = "trace")]
285            if let Some(trace) = device.trace.lock().as_mut() {
286                trace.add(Action::CreateBlas {
287                    id: blas.to_trace(),
288                    desc: desc.clone(),
289                    sizes: trace_sizes,
290                });
291            }
292
293            let id = fid.assign(Fallible::Valid(blas));
294            api_log!("Device::create_blas -> {id:?}");
295
296            return (id, Some(handle), None);
297        };
298
299        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
300        (id, None, Some(error))
301    }
302
303    pub fn device_create_tlas(
304        &self,
305        device_id: id::DeviceId,
306        desc: &resource::TlasDescriptor,
307        id_in: Option<TlasId>,
308    ) -> (TlasId, Option<CreateTlasError>) {
309        profiling::scope!("Device::create_tlas");
310
311        let fid = self.hub.tlas_s.prepare(id_in);
312
313        let error = 'error: {
314            let device = self.hub.devices.get(device_id);
315
316            let tlas = match device.create_tlas(desc) {
317                Ok(tlas) => tlas,
318                Err(e) => break 'error e,
319            };
320
321            #[cfg(feature = "trace")]
322            if let Some(trace) = device.trace.lock().as_mut() {
323                trace.add(Action::CreateTlas {
324                    id: tlas.to_trace(),
325                    desc: desc.clone(),
326                });
327            }
328
329            let id = fid.assign(Fallible::Valid(tlas));
330            api_log!("Device::create_tlas -> {id:?}");
331
332            return (id, None);
333        };
334
335        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
336        (id, Some(error))
337    }
338
339    pub fn blas_drop(&self, blas_id: BlasId) {
340        profiling::scope!("Blas::drop");
341        api_log!("Blas::drop {blas_id:?}");
342
343        let _blas = self.hub.blas_s.remove(blas_id);
344
345        #[cfg(feature = "trace")]
346        if let Ok(blas) = _blas.get() {
347            if let Some(t) = blas.device.trace.lock().as_mut() {
348                t.add(Action::DestroyBlas(blas.to_trace()));
349            }
350        }
351    }
352
353    pub fn tlas_drop(&self, tlas_id: TlasId) {
354        profiling::scope!("Tlas::drop");
355        api_log!("Tlas::drop {tlas_id:?}");
356
357        let _tlas = self.hub.tlas_s.remove(tlas_id);
358
359        #[cfg(feature = "trace")]
360        if let Ok(tlas) = _tlas.get() {
361            if let Some(t) = tlas.device.trace.lock().as_mut() {
362                t.add(Action::DestroyTlas(tlas.to_trace()));
363            }
364        }
365    }
366
367    pub fn blas_prepare_compact_async(
368        &self,
369        blas_id: BlasId,
370        callback: Option<BlasCompactCallback>,
371    ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
372        profiling::scope!("Blas::prepare_compact_async");
373        api_log!("Blas::prepare_compact_async {blas_id:?}");
374
375        let hub = &self.hub;
376
377        let compact_result = match hub.blas_s.get(blas_id).get() {
378            Ok(blas) => blas.prepare_compact_async(callback),
379            Err(e) => Err((callback, e.into())),
380        };
381
382        match compact_result {
383            Ok(submission_index) => Ok(submission_index),
384            Err((mut callback, err)) => {
385                if let Some(callback) = callback.take() {
386                    callback(Err(err.clone()));
387                }
388                Err(err)
389            }
390        }
391    }
392
393    pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
394        profiling::scope!("Blas::prepare_compact_async");
395        api_log!("Blas::prepare_compact_async {blas_id:?}");
396
397        let hub = &self.hub;
398
399        let blas = hub.blas_s.get(blas_id).get()?;
400
401        let lock = blas.compacted_state.lock();
402
403        Ok(matches!(*lock, BlasCompactState::Ready { .. }))
404    }
405}