1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace::{Action, IntoTrace};
6use crate::device::DeviceError;
7use crate::{
8 api_log,
9 device::Device,
10 global::Global,
11 hal_label,
12 id::{self, BlasId, TlasId},
13 lock::RwLock,
14 lock::{rank, Mutex},
15 ray_tracing::BlasPrepareCompactError,
16 ray_tracing::{CreateBlasError, CreateTlasError},
17 resource,
18 resource::{
19 BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20 },
21 snatch::Snatchable,
22 LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::{Features, AABB_GEOMETRY_MIN_STRIDE};
26
27impl Device {
28 pub fn create_blas(
29 self: &Arc<Self>,
30 blas_desc: &resource::BlasDescriptor,
31 sizes: wgt::BlasGeometrySizeDescriptors,
32 ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33 self.check_is_valid()?;
34 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36 if blas_desc
37 .flags
38 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39 {
40 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41 }
42
43 let size_info = match &sizes {
44 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45 if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46 return Err(CreateBlasError::TooManyGeometries(
47 self.limits.max_blas_geometry_count,
48 descriptors.len() as u32,
49 ));
50 }
51
52 let mut entries =
53 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54 descriptors.len(),
55 );
56 for desc in descriptors {
57 if desc.index_count.is_some() != desc.index_format.is_some() {
58 return Err(CreateBlasError::MissingIndexData);
59 }
60 let indices =
61 desc.index_count
62 .map(|count| AccelerationStructureTriangleIndices::<
63 dyn hal::DynBuffer,
64 > {
65 format: desc.index_format.unwrap(),
66 buffer: Some(self.zero_buffer.as_ref()),
67 offset: 0,
68 count,
69 });
70 if !self
71 .features
72 .allowed_vertex_formats_for_blas()
73 .contains(&desc.vertex_format)
74 {
75 return Err(CreateBlasError::InvalidVertexFormat(
76 desc.vertex_format,
77 self.features.allowed_vertex_formats_for_blas(),
78 ));
79 }
80
81 let mut transform = None;
82
83 if blas_desc
84 .flags
85 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86 {
87 transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88 buffer: self.zero_buffer.as_ref(),
89 offset: 0,
90 })
91 }
92
93 if desc.vertex_count > self.limits.max_blas_primitive_count {
94 return Err(CreateBlasError::TooManyPrimitives(
95 self.limits.max_blas_primitive_count,
96 desc.vertex_count,
97 ));
98 }
99
100 entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101 vertex_buffer: Some(self.zero_buffer.as_ref()),
102 vertex_format: desc.vertex_format,
103 first_vertex: 0,
104 vertex_count: desc.vertex_count,
105 vertex_stride: 0,
106 indices,
107 transform,
108 flags: desc.flags,
109 });
110 }
111 unsafe {
112 self.raw().get_acceleration_structure_build_sizes(
113 &hal::GetAccelerationStructureBuildSizesDescriptor {
114 entries: &hal::AccelerationStructureEntries::Triangles(entries),
115 flags: blas_desc.flags,
116 },
117 )
118 }
119 }
120 wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => {
121 if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
122 return Err(CreateBlasError::TooManyGeometries(
123 self.limits.max_blas_geometry_count,
124 descriptors.len() as u32,
125 ));
126 }
127
128 let mut entries =
129 Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::with_capacity(
130 descriptors.len(),
131 );
132 for desc in descriptors {
133 if desc.primitive_count > self.limits.max_blas_primitive_count {
134 return Err(CreateBlasError::TooManyPrimitives(
135 self.limits.max_blas_primitive_count,
136 desc.primitive_count,
137 ));
138 }
139
140 entries.push(hal::AccelerationStructureAABBs::<dyn hal::DynBuffer> {
141 buffer: Some(self.zero_buffer.as_ref()),
142 offset: 0,
143 count: desc.primitive_count,
144 stride: AABB_GEOMETRY_MIN_STRIDE,
145 flags: desc.flags,
146 });
147 }
148 unsafe {
149 self.raw().get_acceleration_structure_build_sizes(
150 &hal::GetAccelerationStructureBuildSizesDescriptor {
151 entries: &hal::AccelerationStructureEntries::AABBs(entries),
152 flags: blas_desc.flags,
153 },
154 )
155 }
156 }
157 };
158
159 let raw = unsafe {
160 self.raw()
161 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
162 label: blas_desc.label.as_deref(),
163 size: size_info.acceleration_structure_size,
164 format: hal::AccelerationStructureFormat::BottomLevel,
165 allow_compaction: blas_desc
166 .flags
167 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
168 })
169 }
170 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
171
172 let compaction_buffer = if blas_desc
173 .flags
174 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
175 {
176 Some(ManuallyDrop::new(unsafe {
177 self.raw()
178 .create_buffer(&hal::BufferDescriptor {
179 label: Some("(wgpu internal) compaction read-back buffer"),
180 size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
181 usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
182 | wgpu_types::BufferUses::MAP_READ,
183 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
184 })
185 .map_err(DeviceError::from_hal)?
186 }))
187 } else {
188 None
189 };
190
191 let handle = unsafe {
192 self.raw()
193 .get_acceleration_structure_device_address(raw.as_ref())
194 };
195
196 Ok(Arc::new(resource::Blas {
197 raw: Snatchable::new(raw),
198 device: self.clone(),
199 size_info,
200 sizes,
201 flags: blas_desc.flags,
202 update_mode: blas_desc.update_mode,
203 handle,
204 label: blas_desc.label.to_string(),
205 built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
206 tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
207 compaction_buffer,
208 compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
209 }))
210 }
211
212 pub fn create_tlas(
213 self: &Arc<Self>,
214 desc: &resource::TlasDescriptor,
215 ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
216 self.check_is_valid()?;
217 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
218
219 if desc.max_instances > self.limits.max_tlas_instance_count {
220 return Err(CreateTlasError::TooManyInstances(
221 self.limits.max_tlas_instance_count,
222 desc.max_instances,
223 ));
224 }
225
226 if desc
227 .flags
228 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
229 {
230 return Err(CreateTlasError::DisallowedFlag(
231 wgt::AccelerationStructureFlags::USE_TRANSFORM,
232 ));
233 }
234
235 if desc
236 .flags
237 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
238 {
239 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
240 }
241
242 let size_info = unsafe {
243 self.raw().get_acceleration_structure_build_sizes(
244 &hal::GetAccelerationStructureBuildSizesDescriptor {
245 entries: &hal::AccelerationStructureEntries::Instances(
246 hal::AccelerationStructureInstances {
247 buffer: Some(self.zero_buffer.as_ref()),
248 offset: 0,
249 count: desc.max_instances,
250 },
251 ),
252 flags: desc.flags,
253 },
254 )
255 };
256
257 let raw = unsafe {
258 self.raw()
259 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
260 label: desc.label.as_deref(),
261 size: size_info.acceleration_structure_size,
262 format: hal::AccelerationStructureFormat::TopLevel,
263 allow_compaction: false,
264 })
265 }
266 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
267
268 let instance_buffer_size = self
269 .alignments
270 .raw_tlas_instance_size
271 .saturating_mul(desc.max_instances.max(1));
272 let instance_buffer = unsafe {
273 self.raw().create_buffer(&hal::BufferDescriptor {
274 label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
275 size: u64::from(instance_buffer_size),
276 usage: wgt::BufferUses::COPY_DST
277 | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
278 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
279 })
280 }
281 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
282
283 Ok(Arc::new(resource::Tlas {
284 raw: Snatchable::new(raw),
285 device: self.clone(),
286 size_info,
287 flags: desc.flags,
288 update_mode: desc.update_mode,
289 built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
290 dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
291 instance_buffer: ManuallyDrop::new(instance_buffer),
292 label: desc.label.to_string(),
293 max_instance_count: desc.max_instances,
294 tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
295 }))
296 }
297}
298
299impl Global {
300 pub fn device_create_blas(
301 &self,
302 device_id: id::DeviceId,
303 desc: &resource::BlasDescriptor,
304 sizes: wgt::BlasGeometrySizeDescriptors,
305 id_in: Option<BlasId>,
306 ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
307 profiling::scope!("Device::create_blas");
308
309 let fid = self.hub.blas_s.prepare(id_in);
310
311 let error = 'error: {
312 let device = self.hub.devices.get(device_id);
313
314 #[cfg(feature = "trace")]
315 let trace_sizes = sizes.clone();
316
317 let blas = match device.create_blas(desc, sizes) {
318 Ok(blas) => blas,
319 Err(e) => break 'error e,
320 };
321 let handle = blas.handle;
322
323 #[cfg(feature = "trace")]
324 if let Some(trace) = device.trace.lock().as_mut() {
325 trace.add(Action::CreateBlas {
326 id: blas.to_trace(),
327 desc: desc.clone(),
328 sizes: trace_sizes,
329 });
330 }
331
332 let id = fid.assign(Fallible::Valid(blas));
333 api_log!("Device::create_blas -> {id:?}");
334
335 return (id, Some(handle), None);
336 };
337
338 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
339 (id, None, Some(error))
340 }
341
342 pub fn device_create_tlas(
343 &self,
344 device_id: id::DeviceId,
345 desc: &resource::TlasDescriptor,
346 id_in: Option<TlasId>,
347 ) -> (TlasId, Option<CreateTlasError>) {
348 profiling::scope!("Device::create_tlas");
349
350 let fid = self.hub.tlas_s.prepare(id_in);
351
352 let error = 'error: {
353 let device = self.hub.devices.get(device_id);
354
355 let tlas = match device.create_tlas(desc) {
356 Ok(tlas) => tlas,
357 Err(e) => break 'error e,
358 };
359
360 #[cfg(feature = "trace")]
361 if let Some(trace) = device.trace.lock().as_mut() {
362 trace.add(Action::CreateTlas {
363 id: tlas.to_trace(),
364 desc: desc.clone(),
365 });
366 }
367
368 let id = fid.assign(Fallible::Valid(tlas));
369 api_log!("Device::create_tlas -> {id:?}");
370
371 return (id, None);
372 };
373
374 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
375 (id, Some(error))
376 }
377
378 pub fn blas_drop(&self, blas_id: BlasId) {
379 profiling::scope!("Blas::drop");
380 api_log!("Blas::drop {blas_id:?}");
381
382 let _blas = self.hub.blas_s.remove(blas_id);
383
384 #[cfg(feature = "trace")]
385 if let Ok(blas) = _blas.get() {
386 if let Some(t) = blas.device.trace.lock().as_mut() {
387 t.add(Action::DestroyBlas(blas.to_trace()));
388 }
389 }
390 }
391
392 pub fn tlas_drop(&self, tlas_id: TlasId) {
393 profiling::scope!("Tlas::drop");
394 api_log!("Tlas::drop {tlas_id:?}");
395
396 let _tlas = self.hub.tlas_s.remove(tlas_id);
397
398 #[cfg(feature = "trace")]
399 if let Ok(tlas) = _tlas.get() {
400 if let Some(t) = tlas.device.trace.lock().as_mut() {
401 t.add(Action::DestroyTlas(tlas.to_trace()));
402 }
403 }
404 }
405
406 pub fn blas_prepare_compact_async(
407 &self,
408 blas_id: BlasId,
409 callback: Option<BlasCompactCallback>,
410 ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
411 profiling::scope!("Blas::prepare_compact_async");
412 api_log!("Blas::prepare_compact_async {blas_id:?}");
413
414 let hub = &self.hub;
415
416 let compact_result = match hub.blas_s.get(blas_id).get() {
417 Ok(blas) => blas.prepare_compact_async(callback),
418 Err(e) => Err((callback, e.into())),
419 };
420
421 match compact_result {
422 Ok(submission_index) => Ok(submission_index),
423 Err((mut callback, err)) => {
424 if let Some(callback) = callback.take() {
425 callback(Err(err.clone()));
426 }
427 Err(err)
428 }
429 }
430 }
431
432 pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
433 profiling::scope!("Blas::prepare_compact_async");
434 api_log!("Blas::prepare_compact_async {blas_id:?}");
435
436 let hub = &self.hub;
437
438 let blas = hub.blas_s.get(blas_id).get()?;
439
440 let lock = blas.compacted_state.lock();
441
442 Ok(matches!(*lock, BlasCompactState::Ready { .. }))
443 }
444}