1use std::{borrow::Cow, iter, mem};
2
3use bytemuck::{Pod, Zeroable};
4use glam::{Affine3A, Mat4, Quat, Vec3};
5use wgpu::util::DeviceExt;
6
7use wgpu::StoreOp;
8
9use crate::utils;
10
11#[repr(C)]
12#[derive(Clone, Copy, Pod, Zeroable)]
13struct GpuAabb {
14 min: [f32; 3],
15 max: [f32; 3],
16 _pad: [f32; 2],
17}
18
19#[repr(C)]
20#[derive(Clone, Copy, Pod, Zeroable)]
21struct Uniforms {
22 view_inverse: Mat4,
23 proj_inverse: Mat4,
24}
25
26#[inline]
27fn affine_to_rows(mat: &Affine3A) -> [f32; 12] {
28 let row_0 = mat.matrix3.row(0);
29 let row_1 = mat.matrix3.row(1);
30 let row_2 = mat.matrix3.row(2);
31 let translation = mat.translation;
32 [
33 row_0.x,
34 row_0.y,
35 row_0.z,
36 translation.x,
37 row_1.x,
38 row_1.y,
39 row_1.z,
40 translation.y,
41 row_2.x,
42 row_2.y,
43 row_2.z,
44 translation.z,
45 ]
46}
47
48struct Example {
49 rt_target: wgpu::Texture,
50 #[expect(dead_code)]
51 rt_view: wgpu::TextureView,
52 #[expect(dead_code)]
53 sampler: wgpu::Sampler,
54 #[expect(dead_code)]
55 uniform_buf: wgpu::Buffer,
56 #[expect(dead_code)]
57 aabb_buf: wgpu::Buffer,
58 tlas: wgpu::Tlas,
59 compute_pipeline: wgpu::ComputePipeline,
60 compute_bind_group: wgpu::BindGroup,
61 blit_pipeline: wgpu::RenderPipeline,
62 blit_bind_group: wgpu::BindGroup,
63 animation_timer: utils::AnimationTimer,
64}
65
66impl crate::framework::Example for Example {
67 fn required_features() -> wgpu::Features {
68 wgpu::Features::TEXTURE_BINDING_ARRAY
69 | wgpu::Features::VERTEX_WRITABLE_STORAGE
70 | wgpu::Features::EXPERIMENTAL_RAY_QUERY
71 }
72
73 fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
74 wgpu::DownlevelCapabilities {
75 flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
76 ..Default::default()
77 }
78 }
79
80 fn required_limits() -> wgpu::Limits {
81 wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
82 }
83
84 fn init(
85 config: &wgpu::SurfaceConfiguration,
86 _adapter: &wgpu::Adapter,
87 device: &wgpu::Device,
88 queue: &wgpu::Queue,
89 ) -> Self {
90 let aabb_data = [
91 GpuAabb {
92 min: [-3.5, -0.5, -0.5],
93 max: [-1.5, 0.5, 0.5],
94 _pad: [0.0; 2],
95 },
96 GpuAabb {
97 min: [-0.5, -0.5, -0.5],
98 max: [0.5, 0.5, 0.5],
99 _pad: [0.0; 2],
100 },
101 GpuAabb {
102 min: [1.5, -0.5, -0.5],
103 max: [3.5, 0.5, 0.5],
104 _pad: [0.0; 2],
105 },
106 ];
107
108 let rt_target = device.create_texture(&wgpu::TextureDescriptor {
109 label: Some("rt_target"),
110 size: wgpu::Extent3d {
111 width: config.width,
112 height: config.height,
113 depth_or_array_layers: 1,
114 },
115 mip_level_count: 1,
116 sample_count: 1,
117 dimension: wgpu::TextureDimension::D2,
118 format: wgpu::TextureFormat::Rgba8Unorm,
119 usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
120 view_formats: &[wgpu::TextureFormat::Rgba8Unorm],
121 });
122
123 let rt_view = rt_target.create_view(&wgpu::TextureViewDescriptor {
124 label: None,
125 format: Some(wgpu::TextureFormat::Rgba8Unorm),
126 dimension: Some(wgpu::TextureViewDimension::D2),
127 usage: None,
128 aspect: wgpu::TextureAspect::All,
129 base_mip_level: 0,
130 mip_level_count: None,
131 base_array_layer: 0,
132 array_layer_count: None,
133 });
134
135 let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
136 label: Some("rt_sampler"),
137 address_mode_u: wgpu::AddressMode::ClampToEdge,
138 address_mode_v: wgpu::AddressMode::ClampToEdge,
139 address_mode_w: wgpu::AddressMode::ClampToEdge,
140 mag_filter: wgpu::FilterMode::Linear,
141 min_filter: wgpu::FilterMode::Linear,
142 mipmap_filter: wgpu::MipmapFilterMode::Nearest,
143 ..Default::default()
144 });
145
146 let uniforms = {
147 let view =
148 Mat4::look_at_rh(Vec3::new(0.0, 0.5, 5.0), Vec3::new(0.0, 0.0, 0.0), Vec3::Y);
149 let proj = Mat4::perspective_rh(
150 59.0_f32.to_radians(),
151 config.width as f32 / config.height as f32,
152 0.001,
153 1000.0,
154 );
155
156 Uniforms {
157 view_inverse: view.inverse(),
158 proj_inverse: proj.inverse(),
159 }
160 };
161
162 let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
163 label: Some("Uniform Buffer"),
164 contents: bytemuck::cast_slice(&[uniforms]),
165 usage: wgpu::BufferUsages::UNIFORM,
166 });
167
168 let aabb_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
169 label: Some("AABB primitives"),
170 contents: bytemuck::cast_slice(&aabb_data),
171 usage: wgpu::BufferUsages::BLAS_INPUT | wgpu::BufferUsages::STORAGE,
172 });
173
174 let aabb_size_desc = wgpu::BlasAABBGeometrySizeDescriptor {
175 primitive_count: aabb_data.len() as u32,
176 flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
177 };
178
179 let blas = device.create_blas(
180 &wgpu::CreateBlasDescriptor {
181 label: None,
182 flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
183 update_mode: wgpu::AccelerationStructureUpdateMode::Build,
184 },
185 wgpu::BlasGeometrySizeDescriptors::AABBs {
186 descriptors: vec![aabb_size_desc.clone()],
187 },
188 );
189
190 let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
191 label: None,
192 flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
193 update_mode: wgpu::AccelerationStructureUpdateMode::Build,
194 max_instances: 1,
195 });
196
197 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
198 label: Some("ray_aabb_compute"),
199 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
200 });
201
202 let blit_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
203 label: Some("blit"),
204 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("blit.wgsl"))),
205 });
206
207 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
208 label: Some("rt_aabb"),
209 layout: None,
210 module: &shader,
211 entry_point: Some("main"),
212 compilation_options: Default::default(),
213 cache: None,
214 });
215
216 let compute_bind_group_layout = compute_pipeline.get_bind_group_layout(0);
217
218 let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
219 label: None,
220 layout: &compute_bind_group_layout,
221 entries: &[
222 wgpu::BindGroupEntry {
223 binding: 0,
224 resource: wgpu::BindingResource::TextureView(&rt_view),
225 },
226 wgpu::BindGroupEntry {
227 binding: 1,
228 resource: uniform_buf.as_entire_binding(),
229 },
230 wgpu::BindGroupEntry {
231 binding: 2,
232 resource: wgpu::BindingResource::AccelerationStructure(&tlas),
233 },
234 wgpu::BindGroupEntry {
235 binding: 3,
236 resource: aabb_buf.as_entire_binding(),
237 },
238 ],
239 });
240
241 let blit_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
242 label: Some("blit"),
243 layout: None,
244 vertex: wgpu::VertexState {
245 module: &blit_shader,
246 entry_point: Some("vs_main"),
247 compilation_options: Default::default(),
248 buffers: &[],
249 },
250 fragment: Some(wgpu::FragmentState {
251 module: &blit_shader,
252 entry_point: Some("fs_main"),
253 compilation_options: Default::default(),
254 targets: &[Some(config.format.into())],
255 }),
256 primitive: wgpu::PrimitiveState {
257 topology: wgpu::PrimitiveTopology::TriangleList,
258 ..Default::default()
259 },
260 depth_stencil: None,
261 multisample: wgpu::MultisampleState::default(),
262 multiview_mask: None,
263 cache: None,
264 });
265
266 let blit_bind_group_layout = blit_pipeline.get_bind_group_layout(0);
267
268 let blit_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
269 label: None,
270 layout: &blit_bind_group_layout,
271 entries: &[
272 wgpu::BindGroupEntry {
273 binding: 0,
274 resource: wgpu::BindingResource::TextureView(&rt_view),
275 },
276 wgpu::BindGroupEntry {
277 binding: 1,
278 resource: wgpu::BindingResource::Sampler(&sampler),
279 },
280 ],
281 });
282
283 tlas[0] = Some(wgpu::TlasInstance::new(
284 &blas,
285 affine_to_rows(&Affine3A::from_rotation_translation(
286 Quat::IDENTITY,
287 Vec3::new(0.0, 0.0, 0.0),
288 )),
289 0,
290 0xff,
291 ));
292
293 let mut encoder =
294 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
295
296 encoder.build_acceleration_structures(
297 iter::once(&wgpu::BlasBuildEntry {
298 blas: &blas,
299 geometry: wgpu::BlasGeometries::AabbGeometries(vec![wgpu::BlasAabbGeometry {
300 size: &aabb_size_desc,
301 stride: mem::size_of::<GpuAabb>() as wgpu::BufferAddress,
302 aabb_buffer: &aabb_buf,
303 primitive_offset: 0,
304 }]),
305 }),
306 iter::once(&tlas),
307 );
308
309 queue.submit(Some(encoder.finish()));
310
311 Example {
312 rt_target,
313 rt_view,
314 sampler,
315 uniform_buf,
316 aabb_buf,
317 tlas,
318 compute_pipeline,
319 compute_bind_group,
320 blit_pipeline,
321 blit_bind_group,
322 animation_timer: utils::AnimationTimer::default(),
323 }
324 }
325
326 fn update(&mut self, _event: winit::event::WindowEvent) {}
327
328 fn resize(
329 &mut self,
330 _config: &wgpu::SurfaceConfiguration,
331 _device: &wgpu::Device,
332 _queue: &wgpu::Queue,
333 ) {
334 }
335
336 fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
337 let anim_time = self.animation_timer.time();
338
339 self.tlas[0].as_mut().unwrap().transform =
340 affine_to_rows(&Affine3A::from_rotation_translation(
341 Quat::from_rotation_y(anim_time * 0.4),
342 Vec3::new(0.0, 0.0, 0.0),
343 ));
344
345 let mut encoder =
346 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
347
348 encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas));
349
350 {
351 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
352 label: None,
353 timestamp_writes: None,
354 });
355 cpass.set_pipeline(&self.compute_pipeline);
356 cpass.set_bind_group(0, Some(&self.compute_bind_group), &[]);
357 cpass.dispatch_workgroups(self.rt_target.width() / 8, self.rt_target.height() / 8, 1);
358 }
359
360 {
361 let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
362 label: None,
363 color_attachments: &[Some(wgpu::RenderPassColorAttachment {
364 view,
365 depth_slice: None,
366 resolve_target: None,
367 ops: wgpu::Operations {
368 load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
369 store: StoreOp::Store,
370 },
371 })],
372 depth_stencil_attachment: None,
373 timestamp_writes: None,
374 occlusion_query_set: None,
375 multiview_mask: None,
376 });
377
378 rpass.set_pipeline(&self.blit_pipeline);
379 rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
380 rpass.draw(0..3, 0..1);
381 }
382
383 queue.submit(Some(encoder.finish()));
384 }
385}
386
387pub fn main() {
388 crate::framework::run::<Example>("ray-aabb");
389}
390
// GPU test registration for this example, compiled only for test builds.
// It runs `Example` at 1024x768 and compares the output with the stored
// screenshot using a mean-difference threshold of 0.02.
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_aabb_compute",
    image_path: "/examples/features/src/ray_aabb_compute/screenshot.png",
    width: 1024,
    height: 768,
    optional_features: wgpu::Features::default(),
    // The test is declared as an expected failure on the Metal backend.
    base_test_parameters: wgpu_test::TestParameters::default()
        .expect_fail(wgpu_test::FailureCase::backend(wgpu::Backends::METAL)),
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    _phantom: std::marker::PhantomData::<Example>,
};