1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace::{Action, IntoTrace};
6use crate::device::DeviceError;
7use crate::{
8 api_log,
9 device::Device,
10 global::Global,
11 hal_label,
12 id::{self, BlasId, TlasId},
13 lock::RwLock,
14 lock::{rank, Mutex},
15 ray_tracing::BlasPrepareCompactError,
16 ray_tracing::{CreateBlasError, CreateTlasError},
17 resource,
18 resource::{
19 BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20 },
21 snatch::Snatchable,
22 LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::{Features, AABB_GEOMETRY_MIN_STRIDE};
26
27impl Device {
28 pub fn create_blas(
29 self: &Arc<Self>,
30 blas_desc: &resource::BlasDescriptor,
31 sizes: wgt::BlasGeometrySizeDescriptors,
32 ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33 self.check_is_valid()?;
34 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36 if blas_desc
37 .flags
38 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39 {
40 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41 }
42
43 let size_info = match &sizes {
44 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45 if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46 return Err(CreateBlasError::TooManyGeometries(
47 self.limits.max_blas_geometry_count,
48 descriptors.len() as u32,
49 ));
50 }
51
52 let mut entries =
53 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54 descriptors.len(),
55 );
56 for desc in descriptors {
57 if desc.index_count.is_some() != desc.index_format.is_some() {
58 return Err(CreateBlasError::MissingIndexData);
59 }
60 let indices =
61 desc.index_count
62 .map(|count| AccelerationStructureTriangleIndices::<
63 dyn hal::DynBuffer,
64 > {
65 format: desc.index_format.unwrap(),
66 buffer: Some(self.zero_buffer.as_ref()),
67 offset: 0,
68 count,
69 });
70 if !self
71 .features
72 .allowed_vertex_formats_for_blas()
73 .contains(&desc.vertex_format)
74 {
75 return Err(CreateBlasError::InvalidVertexFormat(
76 desc.vertex_format,
77 self.features.allowed_vertex_formats_for_blas(),
78 ));
79 }
80
81 let mut transform = None;
82
83 if blas_desc
84 .flags
85 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86 {
87 transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88 buffer: self.zero_buffer.as_ref(),
89 offset: 0,
90 })
91 }
92
93 if desc.vertex_count > self.limits.max_blas_primitive_count {
94 return Err(CreateBlasError::TooManyPrimitives(
95 self.limits.max_blas_primitive_count,
96 desc.vertex_count,
97 ));
98 }
99
100 entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101 vertex_buffer: Some(self.zero_buffer.as_ref()),
102 vertex_format: desc.vertex_format,
103 first_vertex: 0,
104 vertex_count: desc.vertex_count,
105 vertex_stride: 0,
106 indices,
107 transform,
108 flags: desc.flags,
109 });
110 }
111 unsafe {
112 self.raw().get_acceleration_structure_build_sizes(
113 &hal::GetAccelerationStructureBuildSizesDescriptor {
114 entries: &hal::AccelerationStructureEntries::Triangles(entries),
115 flags: blas_desc.flags,
116 },
117 )
118 }
119 }
120 wgt::BlasGeometrySizeDescriptors::AABBs { descriptors } => {
121 if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
122 return Err(CreateBlasError::TooManyGeometries(
123 self.limits.max_blas_geometry_count,
124 descriptors.len() as u32,
125 ));
126 }
127
128 let mut entries =
129 Vec::<hal::AccelerationStructureAABBs<dyn hal::DynBuffer>>::with_capacity(
130 descriptors.len(),
131 );
132 for desc in descriptors {
133 if desc.primitive_count > self.limits.max_blas_primitive_count {
134 return Err(CreateBlasError::TooManyPrimitives(
135 self.limits.max_blas_primitive_count,
136 desc.primitive_count,
137 ));
138 }
139
140 entries.push(hal::AccelerationStructureAABBs::<dyn hal::DynBuffer> {
141 buffer: Some(self.zero_buffer.as_ref()),
142 offset: 0,
143 count: desc.primitive_count,
144 stride: AABB_GEOMETRY_MIN_STRIDE,
145 flags: desc.flags,
146 });
147 }
148 unsafe {
149 self.raw().get_acceleration_structure_build_sizes(
150 &hal::GetAccelerationStructureBuildSizesDescriptor {
151 entries: &hal::AccelerationStructureEntries::AABBs(entries),
152 flags: blas_desc.flags,
153 },
154 )
155 }
156 }
157 };
158
159 let raw = unsafe {
160 self.raw()
161 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
162 label: blas_desc.label.as_deref(),
163 size: size_info.acceleration_structure_size,
164 format: hal::AccelerationStructureFormat::BottomLevel,
165 allow_compaction: blas_desc
166 .flags
167 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
168 })
169 }
170 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
171
172 let compaction_buffer = if blas_desc
173 .flags
174 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
175 {
176 Some(ManuallyDrop::new(unsafe {
177 self.raw()
178 .create_buffer(&hal::BufferDescriptor {
179 label: Some("(wgpu internal) compaction read-back buffer"),
180 size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
181 usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
182 | wgpu_types::BufferUses::MAP_READ,
183 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
184 })
185 .map_err(DeviceError::from_hal)?
186 }))
187 } else {
188 None
189 };
190
191 let handle = unsafe {
192 self.raw()
193 .get_acceleration_structure_device_address(raw.as_ref())
194 };
195
196 Ok(Arc::new(resource::Blas {
197 raw: Snatchable::new(raw),
198 device: self.clone(),
199 size_info,
200 sizes,
201 flags: blas_desc.flags,
202 update_mode: blas_desc.update_mode,
203 handle,
204 label: blas_desc.label.to_string(),
205 built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
206 tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
207 compaction_buffer,
208 compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
209 }))
210 }
211
212 pub fn create_tlas(
213 self: &Arc<Self>,
214 desc: &resource::TlasDescriptor,
215 ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
216 self.check_is_valid()?;
217 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
218
219 if desc.max_instances > self.limits.max_tlas_instance_count {
220 return Err(CreateTlasError::TooManyInstances(
221 self.limits.max_tlas_instance_count,
222 desc.max_instances,
223 ));
224 }
225
226 if desc
227 .flags
228 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
229 {
230 return Err(CreateTlasError::DisallowedFlag(
231 wgt::AccelerationStructureFlags::USE_TRANSFORM,
232 ));
233 }
234
235 if desc
236 .flags
237 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
238 {
239 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
240 }
241
242 let size_info = unsafe {
243 self.raw().get_acceleration_structure_build_sizes(
244 &hal::GetAccelerationStructureBuildSizesDescriptor {
245 entries: &hal::AccelerationStructureEntries::Instances(
246 hal::AccelerationStructureInstances {
247 buffer: Some(self.zero_buffer.as_ref()),
248 offset: 0,
249 count: desc.max_instances,
250 },
251 ),
252 flags: desc.flags,
253 },
254 )
255 };
256
257 let raw = unsafe {
258 self.raw()
259 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
260 label: desc.label.as_deref(),
261 size: size_info.acceleration_structure_size,
262 format: hal::AccelerationStructureFormat::TopLevel,
263 allow_compaction: false,
264 })
265 }
266 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
267
268 let instance_buffer_size = self
269 .alignments
270 .raw_tlas_instance_size
271 .checked_mul(desc.max_instances.max(1))
272 .expect("max_tlas_instance_count should not allow excessive buffer size");
273 let instance_buffer = unsafe {
274 self.raw().create_buffer(&hal::BufferDescriptor {
275 label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
276 size: u64::from(instance_buffer_size),
277 usage: wgt::BufferUses::COPY_DST
278 | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
279 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
280 })
281 }
282 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
283
284 Ok(Arc::new(resource::Tlas {
285 raw: Snatchable::new(raw),
286 device: self.clone(),
287 size_info,
288 flags: desc.flags,
289 update_mode: desc.update_mode,
290 built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
291 dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
292 instance_buffer: ManuallyDrop::new(instance_buffer),
293 label: desc.label.to_string(),
294 max_instance_count: desc.max_instances,
295 tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
296 }))
297 }
298}
299
300impl Global {
301 pub fn device_create_blas(
302 &self,
303 device_id: id::DeviceId,
304 desc: &resource::BlasDescriptor,
305 sizes: wgt::BlasGeometrySizeDescriptors,
306 id_in: Option<BlasId>,
307 ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
308 profiling::scope!("Device::create_blas");
309
310 let fid = self.hub.blas_s.prepare(id_in);
311
312 let error = 'error: {
313 let device = self.hub.devices.get(device_id);
314
315 #[cfg(feature = "trace")]
316 let trace_sizes = sizes.clone();
317
318 let blas = match device.create_blas(desc, sizes) {
319 Ok(blas) => blas,
320 Err(e) => break 'error e,
321 };
322 let handle = blas.handle;
323
324 #[cfg(feature = "trace")]
325 if let Some(trace) = device.trace.lock().as_mut() {
326 trace.add(Action::CreateBlas {
327 id: blas.to_trace(),
328 desc: desc.clone(),
329 sizes: trace_sizes,
330 });
331 }
332
333 let id = fid.assign(Fallible::Valid(blas));
334 api_log!("Device::create_blas -> {id:?}");
335
336 return (id, Some(handle), None);
337 };
338
339 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
340 (id, None, Some(error))
341 }
342
343 pub fn device_create_tlas(
344 &self,
345 device_id: id::DeviceId,
346 desc: &resource::TlasDescriptor,
347 id_in: Option<TlasId>,
348 ) -> (TlasId, Option<CreateTlasError>) {
349 profiling::scope!("Device::create_tlas");
350
351 let fid = self.hub.tlas_s.prepare(id_in);
352
353 let error = 'error: {
354 let device = self.hub.devices.get(device_id);
355
356 let tlas = match device.create_tlas(desc) {
357 Ok(tlas) => tlas,
358 Err(e) => break 'error e,
359 };
360
361 #[cfg(feature = "trace")]
362 if let Some(trace) = device.trace.lock().as_mut() {
363 trace.add(Action::CreateTlas {
364 id: tlas.to_trace(),
365 desc: desc.clone(),
366 });
367 }
368
369 let id = fid.assign(Fallible::Valid(tlas));
370 api_log!("Device::create_tlas -> {id:?}");
371
372 return (id, None);
373 };
374
375 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
376 (id, Some(error))
377 }
378
379 pub fn blas_drop(&self, blas_id: BlasId) {
380 profiling::scope!("Blas::drop");
381 api_log!("Blas::drop {blas_id:?}");
382
383 let _blas = self.hub.blas_s.remove(blas_id);
384
385 #[cfg(feature = "trace")]
386 if let Ok(blas) = _blas.get() {
387 if let Some(t) = blas.device.trace.lock().as_mut() {
388 t.add(Action::DestroyBlas(blas.to_trace()));
389 }
390 }
391 }
392
393 pub fn tlas_drop(&self, tlas_id: TlasId) {
394 profiling::scope!("Tlas::drop");
395 api_log!("Tlas::drop {tlas_id:?}");
396
397 let _tlas = self.hub.tlas_s.remove(tlas_id);
398
399 #[cfg(feature = "trace")]
400 if let Ok(tlas) = _tlas.get() {
401 if let Some(t) = tlas.device.trace.lock().as_mut() {
402 t.add(Action::DestroyTlas(tlas.to_trace()));
403 }
404 }
405 }
406
407 pub fn blas_prepare_compact_async(
408 &self,
409 blas_id: BlasId,
410 callback: Option<BlasCompactCallback>,
411 ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
412 profiling::scope!("Blas::prepare_compact_async");
413 api_log!("Blas::prepare_compact_async {blas_id:?}");
414
415 let hub = &self.hub;
416
417 let compact_result = match hub.blas_s.get(blas_id).get() {
418 Ok(blas) => blas.prepare_compact_async(callback),
419 Err(e) => Err((callback, e.into())),
420 };
421
422 match compact_result {
423 Ok(submission_index) => Ok(submission_index),
424 Err((mut callback, err)) => {
425 if let Some(callback) = callback.take() {
426 callback(Err(err.clone()));
427 }
428 Err(err)
429 }
430 }
431 }
432
433 pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
434 profiling::scope!("Blas::prepare_compact_async");
435 api_log!("Blas::prepare_compact_async {blas_id:?}");
436
437 let hub = &self.hub;
438
439 let blas = hub.blas_s.get(blas_id).get()?;
440
441 let lock = blas.compacted_state.lock();
442
443 Ok(matches!(*lock, BlasCompactState::Ready { .. }))
444 }
445}