1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace;
6use crate::device::DeviceError;
7use crate::{
8 api_log,
9 device::Device,
10 global::Global,
11 hal_label,
12 id::{self, BlasId, TlasId},
13 lock::RwLock,
14 lock::{rank, Mutex},
15 ray_tracing::BlasPrepareCompactError,
16 ray_tracing::{CreateBlasError, CreateTlasError},
17 resource,
18 resource::{
19 BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20 },
21 snatch::Snatchable,
22 LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::Features;
26
27impl Device {
28 fn create_blas(
29 self: &Arc<Self>,
30 blas_desc: &resource::BlasDescriptor,
31 sizes: wgt::BlasGeometrySizeDescriptors,
32 ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33 self.check_is_valid()?;
34 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36 if blas_desc
37 .flags
38 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39 {
40 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41 }
42
43 let size_info = match &sizes {
44 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45 if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46 return Err(CreateBlasError::TooManyGeometries(
47 self.limits.max_blas_geometry_count,
48 descriptors.len() as u32,
49 ));
50 }
51
52 let mut entries =
53 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54 descriptors.len(),
55 );
56 for desc in descriptors {
57 if desc.index_count.is_some() != desc.index_format.is_some() {
58 return Err(CreateBlasError::MissingIndexData);
59 }
60 let indices =
61 desc.index_count
62 .map(|count| AccelerationStructureTriangleIndices::<
63 dyn hal::DynBuffer,
64 > {
65 format: desc.index_format.unwrap(),
66 buffer: None,
67 offset: 0,
68 count,
69 });
70 if !self
71 .features
72 .allowed_vertex_formats_for_blas()
73 .contains(&desc.vertex_format)
74 {
75 return Err(CreateBlasError::InvalidVertexFormat(
76 desc.vertex_format,
77 self.features.allowed_vertex_formats_for_blas(),
78 ));
79 }
80
81 let mut transform = None;
82
83 if blas_desc
84 .flags
85 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86 {
87 transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88 buffer: self.zero_buffer.as_ref(),
89 offset: 0,
90 })
91 }
92
93 if desc.vertex_count > self.limits.max_blas_primitive_count {
94 return Err(CreateBlasError::TooManyPrimitives(
95 self.limits.max_blas_primitive_count,
96 desc.vertex_count,
97 ));
98 }
99
100 entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101 vertex_buffer: None,
102 vertex_format: desc.vertex_format,
103 first_vertex: 0,
104 vertex_count: desc.vertex_count,
105 vertex_stride: 0,
106 indices,
107 transform,
108 flags: desc.flags,
109 });
110 }
111 unsafe {
112 self.raw().get_acceleration_structure_build_sizes(
113 &hal::GetAccelerationStructureBuildSizesDescriptor {
114 entries: &hal::AccelerationStructureEntries::Triangles(entries),
115 flags: blas_desc.flags,
116 },
117 )
118 }
119 }
120 };
121
122 let raw = unsafe {
123 self.raw()
124 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
125 label: blas_desc.label.as_deref(),
126 size: size_info.acceleration_structure_size,
127 format: hal::AccelerationStructureFormat::BottomLevel,
128 allow_compaction: blas_desc
129 .flags
130 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
131 })
132 }
133 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
134
135 let compaction_buffer = if blas_desc
136 .flags
137 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
138 {
139 Some(ManuallyDrop::new(unsafe {
140 self.raw()
141 .create_buffer(&hal::BufferDescriptor {
142 label: Some("(wgpu internal) compaction read-back buffer"),
143 size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
144 usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
145 | wgpu_types::BufferUses::MAP_READ,
146 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
147 })
148 .map_err(DeviceError::from_hal)?
149 }))
150 } else {
151 None
152 };
153
154 let handle = unsafe {
155 self.raw()
156 .get_acceleration_structure_device_address(raw.as_ref())
157 };
158
159 Ok(Arc::new(resource::Blas {
160 raw: Snatchable::new(raw),
161 device: self.clone(),
162 size_info,
163 sizes,
164 flags: blas_desc.flags,
165 update_mode: blas_desc.update_mode,
166 handle,
167 label: blas_desc.label.to_string(),
168 built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
169 tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
170 compaction_buffer,
171 compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
172 }))
173 }
174
175 fn create_tlas(
176 self: &Arc<Self>,
177 desc: &resource::TlasDescriptor,
178 ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
179 self.check_is_valid()?;
180 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
181
182 if desc.max_instances > self.limits.max_tlas_instance_count {
183 return Err(CreateTlasError::TooManyInstances(
184 self.limits.max_tlas_instance_count,
185 desc.max_instances,
186 ));
187 }
188
189 if desc
190 .flags
191 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
192 {
193 return Err(CreateTlasError::DisallowedFlag(
194 wgt::AccelerationStructureFlags::USE_TRANSFORM,
195 ));
196 }
197
198 if desc
199 .flags
200 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
201 {
202 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
203 }
204
205 let size_info = unsafe {
206 self.raw().get_acceleration_structure_build_sizes(
207 &hal::GetAccelerationStructureBuildSizesDescriptor {
208 entries: &hal::AccelerationStructureEntries::Instances(
209 hal::AccelerationStructureInstances {
210 buffer: None,
211 offset: 0,
212 count: desc.max_instances,
213 },
214 ),
215 flags: desc.flags,
216 },
217 )
218 };
219
220 let raw = unsafe {
221 self.raw()
222 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
223 label: desc.label.as_deref(),
224 size: size_info.acceleration_structure_size,
225 format: hal::AccelerationStructureFormat::TopLevel,
226 allow_compaction: false,
227 })
228 }
229 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
230
231 let instance_buffer_size =
232 self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
233 let instance_buffer = unsafe {
234 self.raw().create_buffer(&hal::BufferDescriptor {
235 label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
236 size: instance_buffer_size as u64,
237 usage: wgt::BufferUses::COPY_DST
238 | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
239 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
240 })
241 }
242 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
243
244 Ok(Arc::new(resource::Tlas {
245 raw: Snatchable::new(raw),
246 device: self.clone(),
247 size_info,
248 flags: desc.flags,
249 update_mode: desc.update_mode,
250 built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
251 dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
252 instance_buffer: ManuallyDrop::new(instance_buffer),
253 label: desc.label.to_string(),
254 max_instance_count: desc.max_instances,
255 tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
256 }))
257 }
258}
259
260impl Global {
261 pub fn device_create_blas(
262 &self,
263 device_id: id::DeviceId,
264 desc: &resource::BlasDescriptor,
265 sizes: wgt::BlasGeometrySizeDescriptors,
266 id_in: Option<BlasId>,
267 ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
268 profiling::scope!("Device::create_blas");
269
270 let fid = self.hub.blas_s.prepare(id_in);
271
272 let error = 'error: {
273 let device = self.hub.devices.get(device_id);
274
275 #[cfg(feature = "trace")]
276 if let Some(trace) = device.trace.lock().as_mut() {
277 trace.add(trace::Action::CreateBlas {
278 id: fid.id(),
279 desc: desc.clone(),
280 sizes: sizes.clone(),
281 });
282 }
283
284 let blas = match device.create_blas(desc, sizes) {
285 Ok(blas) => blas,
286 Err(e) => break 'error e,
287 };
288 let handle = blas.handle;
289
290 let id = fid.assign(Fallible::Valid(blas));
291 api_log!("Device::create_blas -> {id:?}");
292
293 return (id, Some(handle), None);
294 };
295
296 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
297 (id, None, Some(error))
298 }
299
300 pub fn device_create_tlas(
301 &self,
302 device_id: id::DeviceId,
303 desc: &resource::TlasDescriptor,
304 id_in: Option<TlasId>,
305 ) -> (TlasId, Option<CreateTlasError>) {
306 profiling::scope!("Device::create_tlas");
307
308 let fid = self.hub.tlas_s.prepare(id_in);
309
310 let error = 'error: {
311 let device = self.hub.devices.get(device_id);
312
313 #[cfg(feature = "trace")]
314 if let Some(trace) = device.trace.lock().as_mut() {
315 trace.add(trace::Action::CreateTlas {
316 id: fid.id(),
317 desc: desc.clone(),
318 });
319 }
320
321 let tlas = match device.create_tlas(desc) {
322 Ok(tlas) => tlas,
323 Err(e) => break 'error e,
324 };
325
326 let id = fid.assign(Fallible::Valid(tlas));
327 api_log!("Device::create_tlas -> {id:?}");
328
329 return (id, None);
330 };
331
332 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
333 (id, Some(error))
334 }
335
336 pub fn blas_drop(&self, blas_id: BlasId) {
337 profiling::scope!("Blas::drop");
338 api_log!("Blas::drop {blas_id:?}");
339
340 let _blas = self.hub.blas_s.remove(blas_id);
341
342 #[cfg(feature = "trace")]
343 if let Ok(blas) = _blas.get() {
344 if let Some(t) = blas.device.trace.lock().as_mut() {
345 t.add(trace::Action::DestroyBlas(blas_id));
346 }
347 }
348 }
349
350 pub fn tlas_drop(&self, tlas_id: TlasId) {
351 profiling::scope!("Tlas::drop");
352 api_log!("Tlas::drop {tlas_id:?}");
353
354 let _tlas = self.hub.tlas_s.remove(tlas_id);
355
356 #[cfg(feature = "trace")]
357 if let Ok(tlas) = _tlas.get() {
358 if let Some(t) = tlas.device.trace.lock().as_mut() {
359 t.add(trace::Action::DestroyTlas(tlas_id));
360 }
361 }
362 }
363
364 pub fn blas_prepare_compact_async(
365 &self,
366 blas_id: BlasId,
367 callback: Option<BlasCompactCallback>,
368 ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
369 profiling::scope!("Blas::prepare_compact_async");
370 api_log!("Blas::prepare_compact_async {blas_id:?}");
371
372 let hub = &self.hub;
373
374 let compact_result = match hub.blas_s.get(blas_id).get() {
375 Ok(blas) => blas.prepare_compact_async(callback),
376 Err(e) => Err((callback, e.into())),
377 };
378
379 match compact_result {
380 Ok(submission_index) => Ok(submission_index),
381 Err((mut callback, err)) => {
382 if let Some(callback) = callback.take() {
383 callback(Err(err.clone()));
384 }
385 Err(err)
386 }
387 }
388 }
389
390 pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
391 profiling::scope!("Blas::prepare_compact_async");
392 api_log!("Blas::prepare_compact_async {blas_id:?}");
393
394 let hub = &self.hub;
395
396 let blas = hub.blas_s.get(blas_id).get()?;
397
398 let lock = blas.compacted_state.lock();
399
400 Ok(matches!(*lock, BlasCompactState::Ready { .. }))
401 }
402}