// wgpu_examples/ray_aabb_compute/mod.rs

1use std::{borrow::Cow, iter, mem};
2
3use bytemuck::{Pod, Zeroable};
4use glam::{Affine3A, Mat4, Quat, Vec3};
5use wgpu::util::DeviceExt;
6
7use wgpu::StoreOp;
8
9use crate::utils;
10
11#[repr(C)]
12#[derive(Clone, Copy, Pod, Zeroable)]
13struct GpuAabb {
14    min: [f32; 3],
15    max: [f32; 3],
16    _pad: [f32; 2],
17}
18
19#[repr(C)]
20#[derive(Clone, Copy, Pod, Zeroable)]
21struct Uniforms {
22    view_inverse: Mat4,
23    proj_inverse: Mat4,
24}
25
26#[inline]
27fn affine_to_rows(mat: &Affine3A) -> [f32; 12] {
28    let row_0 = mat.matrix3.row(0);
29    let row_1 = mat.matrix3.row(1);
30    let row_2 = mat.matrix3.row(2);
31    let translation = mat.translation;
32    [
33        row_0.x,
34        row_0.y,
35        row_0.z,
36        translation.x,
37        row_1.x,
38        row_1.y,
39        row_1.z,
40        translation.y,
41        row_2.x,
42        row_2.y,
43        row_2.z,
44        translation.z,
45    ]
46}
47
48struct Example {
49    rt_target: wgpu::Texture,
50    #[expect(dead_code)]
51    rt_view: wgpu::TextureView,
52    #[expect(dead_code)]
53    sampler: wgpu::Sampler,
54    #[expect(dead_code)]
55    uniform_buf: wgpu::Buffer,
56    #[expect(dead_code)]
57    aabb_buf: wgpu::Buffer,
58    tlas: wgpu::Tlas,
59    compute_pipeline: wgpu::ComputePipeline,
60    compute_bind_group: wgpu::BindGroup,
61    blit_pipeline: wgpu::RenderPipeline,
62    blit_bind_group: wgpu::BindGroup,
63    animation_timer: utils::AnimationTimer,
64}
65
66impl crate::framework::Example for Example {
67    fn required_features() -> wgpu::Features {
68        wgpu::Features::TEXTURE_BINDING_ARRAY
69            | wgpu::Features::VERTEX_WRITABLE_STORAGE
70            | wgpu::Features::EXPERIMENTAL_RAY_QUERY
71    }
72
73    fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
74        wgpu::DownlevelCapabilities {
75            flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
76            ..Default::default()
77        }
78    }
79
80    fn required_limits() -> wgpu::Limits {
81        wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
82    }
83
84    fn init(
85        config: &wgpu::SurfaceConfiguration,
86        _adapter: &wgpu::Adapter,
87        device: &wgpu::Device,
88        queue: &wgpu::Queue,
89    ) -> Self {
90        let aabb_data = [
91            GpuAabb {
92                min: [-3.5, -0.5, -0.5],
93                max: [-1.5, 0.5, 0.5],
94                _pad: [0.0; 2],
95            },
96            GpuAabb {
97                min: [-0.5, -0.5, -0.5],
98                max: [0.5, 0.5, 0.5],
99                _pad: [0.0; 2],
100            },
101            GpuAabb {
102                min: [1.5, -0.5, -0.5],
103                max: [3.5, 0.5, 0.5],
104                _pad: [0.0; 2],
105            },
106        ];
107
108        let rt_target = device.create_texture(&wgpu::TextureDescriptor {
109            label: Some("rt_target"),
110            size: wgpu::Extent3d {
111                width: config.width,
112                height: config.height,
113                depth_or_array_layers: 1,
114            },
115            mip_level_count: 1,
116            sample_count: 1,
117            dimension: wgpu::TextureDimension::D2,
118            format: wgpu::TextureFormat::Rgba8Unorm,
119            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
120            view_formats: &[wgpu::TextureFormat::Rgba8Unorm],
121        });
122
123        let rt_view = rt_target.create_view(&wgpu::TextureViewDescriptor {
124            label: None,
125            format: Some(wgpu::TextureFormat::Rgba8Unorm),
126            dimension: Some(wgpu::TextureViewDimension::D2),
127            usage: None,
128            aspect: wgpu::TextureAspect::All,
129            base_mip_level: 0,
130            mip_level_count: None,
131            base_array_layer: 0,
132            array_layer_count: None,
133        });
134
135        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
136            label: Some("rt_sampler"),
137            address_mode_u: wgpu::AddressMode::ClampToEdge,
138            address_mode_v: wgpu::AddressMode::ClampToEdge,
139            address_mode_w: wgpu::AddressMode::ClampToEdge,
140            mag_filter: wgpu::FilterMode::Linear,
141            min_filter: wgpu::FilterMode::Linear,
142            mipmap_filter: wgpu::MipmapFilterMode::Nearest,
143            ..Default::default()
144        });
145
146        let uniforms = {
147            let view =
148                Mat4::look_at_rh(Vec3::new(0.0, 0.5, 5.0), Vec3::new(0.0, 0.0, 0.0), Vec3::Y);
149            let proj = Mat4::perspective_rh(
150                59.0_f32.to_radians(),
151                config.width as f32 / config.height as f32,
152                0.001,
153                1000.0,
154            );
155
156            Uniforms {
157                view_inverse: view.inverse(),
158                proj_inverse: proj.inverse(),
159            }
160        };
161
162        let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
163            label: Some("Uniform Buffer"),
164            contents: bytemuck::cast_slice(&[uniforms]),
165            usage: wgpu::BufferUsages::UNIFORM,
166        });
167
168        let aabb_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
169            label: Some("AABB primitives"),
170            contents: bytemuck::cast_slice(&aabb_data),
171            usage: wgpu::BufferUsages::BLAS_INPUT | wgpu::BufferUsages::STORAGE,
172        });
173
174        let aabb_size_desc = wgpu::BlasAABBGeometrySizeDescriptor {
175            primitive_count: aabb_data.len() as u32,
176            flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
177        };
178
179        let blas = device.create_blas(
180            &wgpu::CreateBlasDescriptor {
181                label: None,
182                flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
183                update_mode: wgpu::AccelerationStructureUpdateMode::Build,
184            },
185            wgpu::BlasGeometrySizeDescriptors::AABBs {
186                descriptors: vec![aabb_size_desc.clone()],
187            },
188        );
189
190        let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
191            label: None,
192            flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
193            update_mode: wgpu::AccelerationStructureUpdateMode::Build,
194            max_instances: 1,
195        });
196
197        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
198            label: Some("ray_aabb_compute"),
199            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
200        });
201
202        let blit_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
203            label: Some("blit"),
204            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("blit.wgsl"))),
205        });
206
207        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
208            label: Some("rt_aabb"),
209            layout: None,
210            module: &shader,
211            entry_point: Some("main"),
212            compilation_options: Default::default(),
213            cache: None,
214        });
215
216        let compute_bind_group_layout = compute_pipeline.get_bind_group_layout(0);
217
218        let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
219            label: None,
220            layout: &compute_bind_group_layout,
221            entries: &[
222                wgpu::BindGroupEntry {
223                    binding: 0,
224                    resource: wgpu::BindingResource::TextureView(&rt_view),
225                },
226                wgpu::BindGroupEntry {
227                    binding: 1,
228                    resource: uniform_buf.as_entire_binding(),
229                },
230                wgpu::BindGroupEntry {
231                    binding: 2,
232                    resource: wgpu::BindingResource::AccelerationStructure(&tlas),
233                },
234                wgpu::BindGroupEntry {
235                    binding: 3,
236                    resource: aabb_buf.as_entire_binding(),
237                },
238            ],
239        });
240
241        let blit_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
242            label: Some("blit"),
243            layout: None,
244            vertex: wgpu::VertexState {
245                module: &blit_shader,
246                entry_point: Some("vs_main"),
247                compilation_options: Default::default(),
248                buffers: &[],
249            },
250            fragment: Some(wgpu::FragmentState {
251                module: &blit_shader,
252                entry_point: Some("fs_main"),
253                compilation_options: Default::default(),
254                targets: &[Some(config.format.into())],
255            }),
256            primitive: wgpu::PrimitiveState {
257                topology: wgpu::PrimitiveTopology::TriangleList,
258                ..Default::default()
259            },
260            depth_stencil: None,
261            multisample: wgpu::MultisampleState::default(),
262            multiview_mask: None,
263            cache: None,
264        });
265
266        let blit_bind_group_layout = blit_pipeline.get_bind_group_layout(0);
267
268        let blit_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
269            label: None,
270            layout: &blit_bind_group_layout,
271            entries: &[
272                wgpu::BindGroupEntry {
273                    binding: 0,
274                    resource: wgpu::BindingResource::TextureView(&rt_view),
275                },
276                wgpu::BindGroupEntry {
277                    binding: 1,
278                    resource: wgpu::BindingResource::Sampler(&sampler),
279                },
280            ],
281        });
282
283        tlas[0] = Some(wgpu::TlasInstance::new(
284            &blas,
285            affine_to_rows(&Affine3A::from_rotation_translation(
286                Quat::IDENTITY,
287                Vec3::new(0.0, 0.0, 0.0),
288            )),
289            0,
290            0xff,
291        ));
292
293        let mut encoder =
294            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
295
296        encoder.build_acceleration_structures(
297            iter::once(&wgpu::BlasBuildEntry {
298                blas: &blas,
299                geometry: wgpu::BlasGeometries::AabbGeometries(vec![wgpu::BlasAabbGeometry {
300                    size: &aabb_size_desc,
301                    stride: mem::size_of::<GpuAabb>() as wgpu::BufferAddress,
302                    aabb_buffer: &aabb_buf,
303                    primitive_offset: 0,
304                }]),
305            }),
306            iter::once(&tlas),
307        );
308
309        queue.submit(Some(encoder.finish()));
310
311        Example {
312            rt_target,
313            rt_view,
314            sampler,
315            uniform_buf,
316            aabb_buf,
317            tlas,
318            compute_pipeline,
319            compute_bind_group,
320            blit_pipeline,
321            blit_bind_group,
322            animation_timer: utils::AnimationTimer::default(),
323        }
324    }
325
326    fn update(&mut self, _event: winit::event::WindowEvent) {}
327
328    fn resize(
329        &mut self,
330        _config: &wgpu::SurfaceConfiguration,
331        _device: &wgpu::Device,
332        _queue: &wgpu::Queue,
333    ) {
334    }
335
336    fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
337        let anim_time = self.animation_timer.time();
338
339        self.tlas[0].as_mut().unwrap().transform =
340            affine_to_rows(&Affine3A::from_rotation_translation(
341                Quat::from_rotation_y(anim_time * 0.4),
342                Vec3::new(0.0, 0.0, 0.0),
343            ));
344
345        let mut encoder =
346            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
347
348        encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas));
349
350        {
351            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
352                label: None,
353                timestamp_writes: None,
354            });
355            cpass.set_pipeline(&self.compute_pipeline);
356            cpass.set_bind_group(0, Some(&self.compute_bind_group), &[]);
357            cpass.dispatch_workgroups(self.rt_target.width() / 8, self.rt_target.height() / 8, 1);
358        }
359
360        {
361            let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
362                label: None,
363                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
364                    view,
365                    depth_slice: None,
366                    resolve_target: None,
367                    ops: wgpu::Operations {
368                        load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
369                        store: StoreOp::Store,
370                    },
371                })],
372                depth_stencil_attachment: None,
373                timestamp_writes: None,
374                occlusion_query_set: None,
375                multiview_mask: None,
376            });
377
378            rpass.set_pipeline(&self.blit_pipeline);
379            rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
380            rpass.draw(0..3, 0..1);
381        }
382
383        queue.submit(Some(encoder.finish()));
384    }
385}
386
387pub fn main() {
388    crate::framework::run::<Example>("ray-aabb");
389}
390
/// Image-comparison test parameters for this example: a 1024x768 render
/// compared against the stored screenshot with a mean-difference threshold
/// of 0.02.
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_aabb_compute",
    image_path: "/examples/features/src/ray_aabb_compute/screenshot.png",
    width: 1024,
    height: 768,
    optional_features: wgpu::Features::default(),
    base_test_parameters: wgpu_test::TestParameters::default()
        // Expected to fail on Metal: ray queries there do not yet intersect
        // AABB geometry, so the image comparison does not match.
        // https://github.com/gfx-rs/wgpu/pull/9304
        // https://github.com/gfx-rs/wgpu/issues/9100
        .expect_fail(wgpu_test::FailureCase::backend(wgpu::Backends::METAL)),
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    _phantom: std::marker::PhantomData::<Example>,
};