1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::{size_of, ManuallyDrop};
3
4#[cfg(feature = "trace")]
5use crate::device::trace::{Action, IntoTrace};
6use crate::device::DeviceError;
7use crate::{
8 api_log,
9 device::Device,
10 global::Global,
11 hal_label,
12 id::{self, BlasId, TlasId},
13 lock::RwLock,
14 lock::{rank, Mutex},
15 ray_tracing::BlasPrepareCompactError,
16 ray_tracing::{CreateBlasError, CreateTlasError},
17 resource,
18 resource::{
19 BlasCompactCallback, BlasCompactState, Fallible, InvalidResourceError, TrackingData,
20 },
21 snatch::Snatchable,
22 LabelHelpers,
23};
24use hal::AccelerationStructureTriangleIndices;
25use wgt::Features;
26
27impl Device {
28 pub fn create_blas(
29 self: &Arc<Self>,
30 blas_desc: &resource::BlasDescriptor,
31 sizes: wgt::BlasGeometrySizeDescriptors,
32 ) -> Result<Arc<resource::Blas>, CreateBlasError> {
33 self.check_is_valid()?;
34 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
35
36 if blas_desc
37 .flags
38 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
39 {
40 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
41 }
42
43 let size_info = match &sizes {
44 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
45 if descriptors.len() as u32 > self.limits.max_blas_geometry_count {
46 return Err(CreateBlasError::TooManyGeometries(
47 self.limits.max_blas_geometry_count,
48 descriptors.len() as u32,
49 ));
50 }
51
52 let mut entries =
53 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
54 descriptors.len(),
55 );
56 for desc in descriptors {
57 if desc.index_count.is_some() != desc.index_format.is_some() {
58 return Err(CreateBlasError::MissingIndexData);
59 }
60 let indices =
61 desc.index_count
62 .map(|count| AccelerationStructureTriangleIndices::<
63 dyn hal::DynBuffer,
64 > {
65 format: desc.index_format.unwrap(),
66 buffer: None,
67 offset: 0,
68 count,
69 });
70 if !self
71 .features
72 .allowed_vertex_formats_for_blas()
73 .contains(&desc.vertex_format)
74 {
75 return Err(CreateBlasError::InvalidVertexFormat(
76 desc.vertex_format,
77 self.features.allowed_vertex_formats_for_blas(),
78 ));
79 }
80
81 let mut transform = None;
82
83 if blas_desc
84 .flags
85 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
86 {
87 transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
88 buffer: self.zero_buffer.as_ref(),
89 offset: 0,
90 })
91 }
92
93 if desc.vertex_count > self.limits.max_blas_primitive_count {
94 return Err(CreateBlasError::TooManyPrimitives(
95 self.limits.max_blas_primitive_count,
96 desc.vertex_count,
97 ));
98 }
99
100 entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
101 vertex_buffer: None,
102 vertex_format: desc.vertex_format,
103 first_vertex: 0,
104 vertex_count: desc.vertex_count,
105 vertex_stride: 0,
106 indices,
107 transform,
108 flags: desc.flags,
109 });
110 }
111 unsafe {
112 self.raw().get_acceleration_structure_build_sizes(
113 &hal::GetAccelerationStructureBuildSizesDescriptor {
114 entries: &hal::AccelerationStructureEntries::Triangles(entries),
115 flags: blas_desc.flags,
116 },
117 )
118 }
119 }
120 };
121
122 let raw = unsafe {
123 self.raw()
124 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
125 label: blas_desc.label.as_deref(),
126 size: size_info.acceleration_structure_size,
127 format: hal::AccelerationStructureFormat::BottomLevel,
128 allow_compaction: blas_desc
129 .flags
130 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION),
131 })
132 }
133 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
134
135 let compaction_buffer = if blas_desc
136 .flags
137 .contains(wgpu_types::AccelerationStructureFlags::ALLOW_COMPACTION)
138 {
139 Some(ManuallyDrop::new(unsafe {
140 self.raw()
141 .create_buffer(&hal::BufferDescriptor {
142 label: Some("(wgpu internal) compaction read-back buffer"),
143 size: size_of::<wgpu_types::BufferAddress>() as wgpu_types::BufferAddress,
144 usage: wgpu_types::BufferUses::ACCELERATION_STRUCTURE_QUERY
145 | wgpu_types::BufferUses::MAP_READ,
146 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
147 })
148 .map_err(DeviceError::from_hal)?
149 }))
150 } else {
151 None
152 };
153
154 let handle = unsafe {
155 self.raw()
156 .get_acceleration_structure_device_address(raw.as_ref())
157 };
158
159 Ok(Arc::new(resource::Blas {
160 raw: Snatchable::new(raw),
161 device: self.clone(),
162 size_info,
163 sizes,
164 flags: blas_desc.flags,
165 update_mode: blas_desc.update_mode,
166 handle,
167 label: blas_desc.label.to_string(),
168 built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
169 tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
170 compaction_buffer,
171 compacted_state: Mutex::new(rank::BLAS_COMPACTION_STATE, BlasCompactState::Idle),
172 }))
173 }
174
175 pub fn create_tlas(
176 self: &Arc<Self>,
177 desc: &resource::TlasDescriptor,
178 ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
179 self.check_is_valid()?;
180 self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?;
181
182 if desc.max_instances > self.limits.max_tlas_instance_count {
183 return Err(CreateTlasError::TooManyInstances(
184 self.limits.max_tlas_instance_count,
185 desc.max_instances,
186 ));
187 }
188
189 if desc
190 .flags
191 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
192 {
193 return Err(CreateTlasError::DisallowedFlag(
194 wgt::AccelerationStructureFlags::USE_TRANSFORM,
195 ));
196 }
197
198 if desc
199 .flags
200 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
201 {
202 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
203 }
204
205 let size_info = unsafe {
206 self.raw().get_acceleration_structure_build_sizes(
207 &hal::GetAccelerationStructureBuildSizesDescriptor {
208 entries: &hal::AccelerationStructureEntries::Instances(
209 hal::AccelerationStructureInstances {
210 buffer: None,
211 offset: 0,
212 count: desc.max_instances,
213 },
214 ),
215 flags: desc.flags,
216 },
217 )
218 };
219
220 let raw = unsafe {
221 self.raw()
222 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
223 label: desc.label.as_deref(),
224 size: size_info.acceleration_structure_size,
225 format: hal::AccelerationStructureFormat::TopLevel,
226 allow_compaction: false,
227 })
228 }
229 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
230
231 let instance_buffer_size =
232 self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
233 let instance_buffer = unsafe {
234 self.raw().create_buffer(&hal::BufferDescriptor {
235 label: hal_label(Some("(wgpu-core) instances_buffer"), self.instance_flags),
236 size: instance_buffer_size as u64,
237 usage: wgt::BufferUses::COPY_DST
238 | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
239 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
240 })
241 }
242 .map_err(|e| self.handle_hal_error_with_nonfatal_oom(e))?;
243
244 Ok(Arc::new(resource::Tlas {
245 raw: Snatchable::new(raw),
246 device: self.clone(),
247 size_info,
248 flags: desc.flags,
249 update_mode: desc.update_mode,
250 built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
251 dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
252 instance_buffer: ManuallyDrop::new(instance_buffer),
253 label: desc.label.to_string(),
254 max_instance_count: desc.max_instances,
255 tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
256 }))
257 }
258}
259
260impl Global {
261 pub fn device_create_blas(
262 &self,
263 device_id: id::DeviceId,
264 desc: &resource::BlasDescriptor,
265 sizes: wgt::BlasGeometrySizeDescriptors,
266 id_in: Option<BlasId>,
267 ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
268 profiling::scope!("Device::create_blas");
269
270 let fid = self.hub.blas_s.prepare(id_in);
271
272 let error = 'error: {
273 let device = self.hub.devices.get(device_id);
274
275 #[cfg(feature = "trace")]
276 let trace_sizes = sizes.clone();
277
278 let blas = match device.create_blas(desc, sizes) {
279 Ok(blas) => blas,
280 Err(e) => break 'error e,
281 };
282 let handle = blas.handle;
283
284 #[cfg(feature = "trace")]
285 if let Some(trace) = device.trace.lock().as_mut() {
286 trace.add(Action::CreateBlas {
287 id: blas.to_trace(),
288 desc: desc.clone(),
289 sizes: trace_sizes,
290 });
291 }
292
293 let id = fid.assign(Fallible::Valid(blas));
294 api_log!("Device::create_blas -> {id:?}");
295
296 return (id, Some(handle), None);
297 };
298
299 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
300 (id, None, Some(error))
301 }
302
303 pub fn device_create_tlas(
304 &self,
305 device_id: id::DeviceId,
306 desc: &resource::TlasDescriptor,
307 id_in: Option<TlasId>,
308 ) -> (TlasId, Option<CreateTlasError>) {
309 profiling::scope!("Device::create_tlas");
310
311 let fid = self.hub.tlas_s.prepare(id_in);
312
313 let error = 'error: {
314 let device = self.hub.devices.get(device_id);
315
316 let tlas = match device.create_tlas(desc) {
317 Ok(tlas) => tlas,
318 Err(e) => break 'error e,
319 };
320
321 #[cfg(feature = "trace")]
322 if let Some(trace) = device.trace.lock().as_mut() {
323 trace.add(Action::CreateTlas {
324 id: tlas.to_trace(),
325 desc: desc.clone(),
326 });
327 }
328
329 let id = fid.assign(Fallible::Valid(tlas));
330 api_log!("Device::create_tlas -> {id:?}");
331
332 return (id, None);
333 };
334
335 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
336 (id, Some(error))
337 }
338
339 pub fn blas_drop(&self, blas_id: BlasId) {
340 profiling::scope!("Blas::drop");
341 api_log!("Blas::drop {blas_id:?}");
342
343 let _blas = self.hub.blas_s.remove(blas_id);
344
345 #[cfg(feature = "trace")]
346 if let Ok(blas) = _blas.get() {
347 if let Some(t) = blas.device.trace.lock().as_mut() {
348 t.add(Action::DestroyBlas(blas.to_trace()));
349 }
350 }
351 }
352
353 pub fn tlas_drop(&self, tlas_id: TlasId) {
354 profiling::scope!("Tlas::drop");
355 api_log!("Tlas::drop {tlas_id:?}");
356
357 let _tlas = self.hub.tlas_s.remove(tlas_id);
358
359 #[cfg(feature = "trace")]
360 if let Ok(tlas) = _tlas.get() {
361 if let Some(t) = tlas.device.trace.lock().as_mut() {
362 t.add(Action::DestroyTlas(tlas.to_trace()));
363 }
364 }
365 }
366
367 pub fn blas_prepare_compact_async(
368 &self,
369 blas_id: BlasId,
370 callback: Option<BlasCompactCallback>,
371 ) -> Result<crate::SubmissionIndex, BlasPrepareCompactError> {
372 profiling::scope!("Blas::prepare_compact_async");
373 api_log!("Blas::prepare_compact_async {blas_id:?}");
374
375 let hub = &self.hub;
376
377 let compact_result = match hub.blas_s.get(blas_id).get() {
378 Ok(blas) => blas.prepare_compact_async(callback),
379 Err(e) => Err((callback, e.into())),
380 };
381
382 match compact_result {
383 Ok(submission_index) => Ok(submission_index),
384 Err((mut callback, err)) => {
385 if let Some(callback) = callback.take() {
386 callback(Err(err.clone()));
387 }
388 Err(err)
389 }
390 }
391 }
392
393 pub fn ready_for_compaction(&self, blas_id: BlasId) -> Result<bool, InvalidResourceError> {
394 profiling::scope!("Blas::prepare_compact_async");
395 api_log!("Blas::prepare_compact_async {blas_id:?}");
396
397 let hub = &self.hub;
398
399 let blas = hub.blas_s.get(blas_id).get()?;
400
401 let lock = blas.compacted_state.lock();
402
403 Ok(matches!(*lock, BlasCompactState::Ready { .. }))
404 }
405}