wgpu_core/device/ray_tracing.rs

1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace;
6use crate::device::DeviceError;
7use crate::{
8    api_log,
9    device::Device,
10    global::Global,
11    hal_label,
12    id::{self, BlasId, TlasId},
13    lock::RwLock,
14    lock::{rank, Mutex},
15    ray_tracing::BlasPrepareCompactError,
16    ray_tracing::{CreateBlasError, CreateTlasError},
17    resource,
18    resource::{
19        BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20    },
21    snatch::Snatchable,
22    LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::Features;
26
27impl Device {
28    fn create_blas(
29        self: &Arc<Self>,
30        blas_desc: &resource::BlasDescriptor,
31        sizes: wgt::BlasGeometrySizeDescriptors,
32    ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33        self.check_is_valid()?;
34        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36        if blas_desc
37            .flags
38            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39        {
40            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41        }
42
43        let size_info = match &sizes {
44            wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45                if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46                    return Err(CreateBlasError::TooManyGeometries(
47                        self.limits.max_blas_geometry_count,
48                        descriptors.len() as u32,
49                    ));
50                }
51
52                let mut entries =
53                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54                        descriptors.len(),
55                    );
56                for desc in descriptors {
57                    if desc.index_count.is_some() != desc.index_format.is_some() {
58                        return Err(CreateBlasError::MissingIndexData);
59                    }
60                    let indices =
61                        desc.index_count
62                            .map(|count| AccelerationStructureTriangleIndices::<
63                                dyn hal::DynBuffer,
64                            > {
65                                format: desc.index_format.unwrap(),
66                                buffer: None,
67                                offset: 0,
68                                count,
69                            });
70                    if !self
71                        .features
72                        .allowed_vertex_formats_for_blas()
73                        .contains(&desc.vertex_format)
74                    {
75                        return Err(CreateBlasError::InvalidVertexFormat(
76                            desc.vertex_format,
77                            self.features.allowed_vertex_formats_for_blas(),
78                        ));
79                    }
80
81                    let mut transform = None;
82
83                    if blas_desc
84                        .flags
85                        .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86                    {
87                        transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88                            buffer: self.zero_buffer.as_ref(),
89                            offset: 0,
90                        })
91                    }
92
93                    if desc.vertex_count > self.limits.max_blas_primitive_count {
94                        return Err(CreateBlasError::TooManyPrimitives(
95                            self.limits.max_blas_primitive_count,
96                            desc.vertex_count,
97                        ));
98                    }
99
100                    entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101                        vertex_buffer: None,
102                        vertex_format: desc.vertex_format,
103                        first_vertex: 0,
104                        vertex_count: desc.vertex_count,
105                        vertex_stride: 0,
106                        indices,
107                        transform,
108                        flags: desc.flags,
109                    });
110                }
111                unsafe {
112                    self.raw().get_acceleration_structure_build_sizes(
113                        &hal::GetAccelerationStructureBuildSizesDescriptor {
114                            entries: &hal::AccelerationStructureEntries::Triangles(entries),
115                            flags: blas_desc.flags,
116                        },
117                    )
118                }
119            }
120        };
121
122        let raw = unsafe {
123            self.raw()
124                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
125                    label: blas_desc.label.as_deref(),
126                    size: size_info.acceleration_structure_size,
127                    format: hal::AccelerationStructureFormat::BottomLevel,
128                    allow_compaction: blas_desc
129                        .flags
130                        .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
131                })
132        }
133        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
134
135        let compaction_buffer = if blas_desc
136            .flags
137            .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
138        {
139            Some(ManuallyDrop::new(unsafe {
140                self.raw()
141                    .create_buffer(&hal::BufferDescriptor {
142                        label: Some("(wgpu internal) compaction read-back buffer"),
143                        size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
144                        usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
145                            | wgpu_types::BufferUses::MAP_READ,
146                        memory_flags: hal::MemoryFlags::PREFER_COHERENT,
147                    })
148                    .map_err(DeviceError::from_hal)?
149            }))
150        } else {
151            None
152        };
153
154        let handle = unsafe {
155            self.raw()
156                .get_acceleration_structure_device_address(raw.as_ref())
157        };
158
159        Ok(Arc::new(resource::Blas {
160            raw: Snatchable::new(raw),
161            device: self.clone(),
162            size_info,
163            sizes,
164            flags: blas_desc.flags,
165            update_mode: blas_desc.update_mode,
166            handle,
167            label: blas_desc.label.to_string(),
168            built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
169            tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
170            compaction_buffer,
171            compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
172        }))
173    }
174
175    fn create_tlas(
176        self: &Arc<Self>,
177        desc: &resource::TlasDescriptor,
178    ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
179        self.check_is_valid()?;
180        self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
181
182        if desc.max_instances > self.limits.max_tlas_instance_count {
183            return Err(CreateTlasError::TooManyInstances(
184                self.limits.max_tlas_instance_count,
185                desc.max_instances,
186            ));
187        }
188
189        if desc
190            .flags
191            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
192        {
193            return Err(CreateTlasError::DisallowedFlag(
194                wgt::AccelerationStructureFlags::USE_TRANSFORM,
195            ));
196        }
197
198        if desc
199            .flags
200            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
201        {
202            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
203        }
204
205        let size_info = unsafe {
206            self.raw().get_acceleration_structure_build_sizes(
207                &hal::GetAccelerationStructureBuildSizesDescriptor {
208                    entries: &hal::AccelerationStructureEntries::Instances(
209                        hal::AccelerationStructureInstances {
210                            buffer: None,
211                            offset: 0,
212                            count: desc.max_instances,
213                        },
214                    ),
215                    flags: desc.flags,
216                },
217            )
218        };
219
220        let raw = unsafe {
221            self.raw()
222                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
223                    label: desc.label.as_deref(),
224                    size: size_info.acceleration_structure_size,
225                    format: hal::AccelerationStructureFormat::TopLevel,
226                    allow_compaction: false,
227                })
228        }
229        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
230
231        let instance_buffer_size =
232            self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
233        let instance_buffer = unsafe {
234            self.raw().create_buffer(&hal::BufferDescriptor {
235                label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
236                size: instance_buffer_size as u64,
237                usage: wgt::BufferUses::COPY_DST
238                    | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
239                memory_flags: hal::MemoryFlags::PREFER_COHERENT,
240            })
241        }
242        .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
243
244        Ok(Arc::new(resource::Tlas {
245            raw: Snatchable::new(raw),
246            device: self.clone(),
247            size_info,
248            flags: desc.flags,
249            update_mode: desc.update_mode,
250            built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
251            dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
252            instance_buffer: ManuallyDrop::new(instance_buffer),
253            label: desc.label.to_string(),
254            max_instance_count: desc.max_instances,
255            tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
256        }))
257    }
258}
259
260impl Global {
261    pub fn device_create_blas(
262        &self,
263        device_id: id::DeviceId,
264        desc: &resource::BlasDescriptor,
265        sizes: wgt::BlasGeometrySizeDescriptors,
266        id_in: Option<BlasId>,
267    ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
268        profiling::scope!("Device::create_blas");
269
270        let fid = self.hub.blas_s.prepare(id_in);
271
272        let error = 'error: {
273            let device = self.hub.devices.get(device_id);
274
275            #[cfg(feature = "trace")]
276            if let Some(trace) = device.trace.lock().as_mut() {
277                trace.add(trace::Action::CreateBlas {
278                    id: fid.id(),
279                    desc: desc.clone(),
280                    sizes: sizes.clone(),
281                });
282            }
283
284            let blas = match device.create_blas(desc, sizes) {
285                Ok(blas) => blas,
286                Err(e) => break 'error e,
287            };
288            let handle = blas.handle;
289
290            let id = fid.assign(Fallible::Valid(blas));
291            api_log!("Device::create_blas -> {id:?}");
292
293            return (id, Some(handle), None);
294        };
295
296        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
297        (id, None, Some(error))
298    }
299
300    pub fn device_create_tlas(
301        &self,
302        device_id: id::DeviceId,
303        desc: &resource::TlasDescriptor,
304        id_in: Option<TlasId>,
305    ) -> (TlasId, Option<CreateTlasError>) {
306        profiling::scope!("Device::create_tlas");
307
308        let fid = self.hub.tlas_s.prepare(id_in);
309
310        let error = 'error: {
311            let device = self.hub.devices.get(device_id);
312
313            #[cfg(feature = "trace")]
314            if let Some(trace) = device.trace.lock().as_mut() {
315                trace.add(trace::Action::CreateTlas {
316                    id: fid.id(),
317                    desc: desc.clone(),
318                });
319            }
320
321            let tlas = match device.create_tlas(desc) {
322                Ok(tlas) => tlas,
323                Err(e) => break 'error e,
324            };
325
326            let id = fid.assign(Fallible::Valid(tlas));
327            api_log!("Device::create_tlas -> {id:?}");
328
329            return (id, None);
330        };
331
332        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
333        (id, Some(error))
334    }
335
336    pub fn blas_drop(&self, blas_id: BlasId) {
337        profiling::scope!("Blas::drop");
338        api_log!("Blas::drop {blas_id:?}");
339
340        let _blas = self.hub.blas_s.remove(blas_id);
341
342        #[cfg(feature = "trace")]
343        if let Ok(blas) = _blas.get() {
344            if let Some(t) = blas.device.trace.lock().as_mut() {
345                t.add(trace::Action::DestroyBlas(blas_id));
346            }
347        }
348    }
349
350    pub fn tlas_drop(&self, tlas_id: TlasId) {
351        profiling::scope!("Tlas::drop");
352        api_log!("Tlas::drop {tlas_id:?}");
353
354        let _tlas = self.hub.tlas_s.remove(tlas_id);
355
356        #[cfg(feature = "trace")]
357        if let Ok(tlas) = _tlas.get() {
358            if let Some(t) = tlas.device.trace.lock().as_mut() {
359                t.add(trace::Action::DestroyTlas(tlas_id));
360            }
361        }
362    }
363
364    pub fn blas_prepare_compact_async(
365        &self,
366        blas_id: BlasId,
367        callback: Option<BlasCompactCallback>,
368    ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
369        profiling::scope!("Blas::prepare_compact_async");
370        api_log!("Blas::prepare_compact_async {blas_id:?}");
371
372        let hub = &self.hub;
373
374        let compact_result = match hub.blas_s.get(blas_id).get() {
375            Ok(blas) => blas.prepare_compact_async(callback),
376            Err(e) => Err((callback, e.into())),
377        };
378
379        match compact_result {
380            Ok(submission_index) => Ok(submission_index),
381            Err((mut callback, err)) => {
382                if let Some(callback) = callback.take() {
383                    callback(Err(err.clone()));
384                }
385                Err(err)
386            }
387        }
388    }
389
390    pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
391        profiling::scope!("Blas::prepare_compact_async");
392        api_log!("Blas::prepare_compact_async {blas_id:?}");
393
394        let hub = &self.hub;
395
396        let blas = hub.blas_s.get(blas_id).get()?;
397
398        let lock = blas.compacted_state.lock();
399
400        Ok(matches!(*lock, BlasCompactState::Ready { .. }))
401    }
402}