wgpu_examples/ray_cube_compute/
mod.rs

1use std::{borrow::Cow, iter, mem};
2
3use bytemuck::{Pod, Zeroable};
4use glam::{Affine3A, Mat4, Quat, Vec3};
5use wgpu::util::DeviceExt;
6
7use wgpu::StoreOp;
8
9use crate::utils;
10
// from cube
/// One cube vertex, laid out exactly as it is uploaded to the vertex buffer.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` allows a slice of vertices to be
/// reinterpreted as bytes via `bytemuck::cast_slice` when the buffer is created.
/// The BLAS build uses `size_of::<Vertex>()` as the vertex stride and declares the
/// position format as `Float32x3`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Vertex {
    /// Position as `[x, y, z, w]`; `vertex` always stores `1.0` in `w`.
    _pos: [f32; 4],
    /// Texture coordinate `[u, v]`.
    _tex_coord: [f32; 2],
}
18
19fn vertex(pos: [i8; 3], tc: [i8; 2]) -> Vertex {
20    Vertex {
21        _pos: [pos[0] as f32, pos[1] as f32, pos[2] as f32, 1.0],
22        _tex_coord: [tc[0] as f32, tc[1] as f32],
23    }
24}
25
26fn create_vertices() -> (Vec<Vertex>, Vec<u16>) {
27    let vertex_data = [
28        // top (0, 0, 1)
29        vertex([-1, -1, 1], [0, 0]),
30        vertex([1, -1, 1], [1, 0]),
31        vertex([1, 1, 1], [1, 1]),
32        vertex([-1, 1, 1], [0, 1]),
33        // bottom (0, 0, -1)
34        vertex([-1, 1, -1], [1, 0]),
35        vertex([1, 1, -1], [0, 0]),
36        vertex([1, -1, -1], [0, 1]),
37        vertex([-1, -1, -1], [1, 1]),
38        // right (1, 0, 0)
39        vertex([1, -1, -1], [0, 0]),
40        vertex([1, 1, -1], [1, 0]),
41        vertex([1, 1, 1], [1, 1]),
42        vertex([1, -1, 1], [0, 1]),
43        // left (-1, 0, 0)
44        vertex([-1, -1, 1], [1, 0]),
45        vertex([-1, 1, 1], [0, 0]),
46        vertex([-1, 1, -1], [0, 1]),
47        vertex([-1, -1, -1], [1, 1]),
48        // front (0, 1, 0)
49        vertex([1, 1, -1], [1, 0]),
50        vertex([-1, 1, -1], [0, 0]),
51        vertex([-1, 1, 1], [0, 1]),
52        vertex([1, 1, 1], [1, 1]),
53        // back (0, -1, 0)
54        vertex([1, -1, 1], [0, 0]),
55        vertex([-1, -1, 1], [1, 0]),
56        vertex([-1, -1, -1], [1, 1]),
57        vertex([1, -1, -1], [0, 1]),
58    ];
59
60    let index_data: &[u16] = &[
61        0, 1, 2, 2, 3, 0, // top
62        4, 5, 6, 6, 7, 4, // bottom
63        8, 9, 10, 10, 11, 8, // right
64        12, 13, 14, 14, 15, 12, // left
65        16, 17, 18, 18, 19, 16, // front
66        20, 21, 22, 22, 23, 20, // back
67    ];
68
69    (vertex_data.to_vec(), index_data.to_vec())
70}
71
/// Camera data uploaded once to the uniform buffer, which is bound at binding 1
/// of the compute bind group.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` let the value be turned into bytes with
/// `bytemuck::cast_slice`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Uniforms {
    /// Inverse of the view matrix.
    view_inverse: Mat4,
    /// Inverse of the projection matrix.
    proj_inverse: Mat4,
}
78
79#[inline]
80fn affine_to_rows(mat: &Affine3A) -> [f32; 12] {
81    let row_0 = mat.matrix3.row(0);
82    let row_1 = mat.matrix3.row(1);
83    let row_2 = mat.matrix3.row(2);
84    let translation = mat.translation;
85    [
86        row_0.x,
87        row_0.y,
88        row_0.z,
89        translation.x,
90        row_1.x,
91        row_1.y,
92        row_1.z,
93        translation.y,
94        row_2.x,
95        row_2.y,
96        row_2.z,
97        translation.z,
98    ]
99}
100
/// State created in `init` and kept for the lifetime of the example.
///
/// Fields marked `#[expect(dead_code)]` are stored here but never read again in
/// this file after `init` has used them.
struct Example {
    /// Off-screen `Rgba8Unorm` texture created at the initial surface size; its
    /// width and height determine the compute dispatch size in `render`.
    rt_target: wgpu::Texture,
    /// View of `rt_target`; only used while building the two bind groups in `init`.
    #[expect(dead_code)]
    rt_view: wgpu::TextureView,
    /// Sampler placed in the blit bind group; only used in `init`.
    #[expect(dead_code)]
    sampler: wgpu::Sampler,
    /// Buffer holding the `Uniforms` value; only used while building the compute
    /// bind group in `init`.
    #[expect(dead_code)]
    uniform_buf: wgpu::Buffer,
    /// Cube vertex data handed to the BLAS build in `init`.
    #[expect(dead_code)]
    vertex_buf: wgpu::Buffer,
    /// Cube index data handed to the BLAS build in `init`.
    #[expect(dead_code)]
    index_buf: wgpu::Buffer,
    /// Top-level acceleration structure; instance 0 gets a new transform and the
    /// TLAS is rebuilt on every `render` call.
    tlas: wgpu::Tlas,
    /// Compute pipeline built from `shader.wgsl`.
    compute_pipeline: wgpu::ComputePipeline,
    /// Bind group 0 of the compute pipeline: `rt_view`, the uniform buffer and the TLAS.
    compute_bind_group: wgpu::BindGroup,
    /// Render pipeline built from `blit.wgsl`.
    blit_pipeline: wgpu::RenderPipeline,
    /// Bind group 0 of the blit pipeline: `rt_view` and the sampler.
    blit_bind_group: wgpu::BindGroup,
    /// Source of the animation time used for the per-frame rotation.
    animation_timer: utils::AnimationTimer,
}
120
121impl crate::framework::Example for Example {
122    fn required_features() -> wgpu::Features {
123        wgpu::Features::TEXTURE_BINDING_ARRAY
124            | wgpu::Features::VERTEX_WRITABLE_STORAGE
125            | wgpu::Features::EXPERIMENTAL_RAY_QUERY
126    }
127
128    fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
129        wgpu::DownlevelCapabilities {
130            flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
131            ..Default::default()
132        }
133    }
134
135    fn required_limits() -> wgpu::Limits {
136        wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
137    }
138
139    fn init(
140        config: &wgpu::SurfaceConfiguration,
141        _adapter: &wgpu::Adapter,
142        device: &wgpu::Device,
143        queue: &wgpu::Queue,
144    ) -> Self {
145        let side_count = 8;
146
147        let rt_target = device.create_texture(&wgpu::TextureDescriptor {
148            label: Some("rt_target"),
149            size: wgpu::Extent3d {
150                width: config.width,
151                height: config.height,
152                depth_or_array_layers: 1,
153            },
154            mip_level_count: 1,
155            sample_count: 1,
156            dimension: wgpu::TextureDimension::D2,
157            format: wgpu::TextureFormat::Rgba8Unorm,
158            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
159            view_formats: &[wgpu::TextureFormat::Rgba8Unorm],
160        });
161
162        let rt_view = rt_target.create_view(&wgpu::TextureViewDescriptor {
163            label: None,
164            format: Some(wgpu::TextureFormat::Rgba8Unorm),
165            dimension: Some(wgpu::TextureViewDimension::D2),
166            usage: None,
167            aspect: wgpu::TextureAspect::All,
168            base_mip_level: 0,
169            mip_level_count: None,
170            base_array_layer: 0,
171            array_layer_count: None,
172        });
173
174        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
175            label: Some("rt_sampler"),
176            address_mode_u: wgpu::AddressMode::ClampToEdge,
177            address_mode_v: wgpu::AddressMode::ClampToEdge,
178            address_mode_w: wgpu::AddressMode::ClampToEdge,
179            mag_filter: wgpu::FilterMode::Linear,
180            min_filter: wgpu::FilterMode::Linear,
181            mipmap_filter: wgpu::MipmapFilterMode::Nearest,
182            ..Default::default()
183        });
184
185        let uniforms = {
186            let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y);
187            let proj = Mat4::perspective_rh(
188                59.0_f32.to_radians(),
189                config.width as f32 / config.height as f32,
190                0.001,
191                1000.0,
192            );
193
194            Uniforms {
195                view_inverse: view.inverse(),
196                proj_inverse: proj.inverse(),
197            }
198        };
199
200        let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
201            label: Some("Uniform Buffer"),
202            contents: bytemuck::cast_slice(&[uniforms]),
203            usage: wgpu::BufferUsages::UNIFORM,
204        });
205
206        let (vertex_data, index_data) = create_vertices();
207
208        let vertex_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
209            label: Some("Vertex Buffer"),
210            contents: bytemuck::cast_slice(&vertex_data),
211            usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::BLAS_INPUT,
212        });
213
214        let index_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
215            label: Some("Index Buffer"),
216            contents: bytemuck::cast_slice(&index_data),
217            usage: wgpu::BufferUsages::INDEX | wgpu::BufferUsages::BLAS_INPUT,
218        });
219
220        let blas_geo_size_desc = wgpu::BlasTriangleGeometrySizeDescriptor {
221            vertex_format: wgpu::VertexFormat::Float32x3,
222            vertex_count: vertex_data.len() as u32,
223            index_format: Some(wgpu::IndexFormat::Uint16),
224            index_count: Some(index_data.len() as u32),
225            flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
226        };
227
228        let blas = device.create_blas(
229            &wgpu::CreateBlasDescriptor {
230                label: None,
231                flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
232                update_mode: wgpu::AccelerationStructureUpdateMode::Build,
233            },
234            wgpu::BlasGeometrySizeDescriptors::Triangles {
235                descriptors: vec![blas_geo_size_desc.clone()],
236            },
237        );
238
239        let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
240            label: None,
241            flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
242            update_mode: wgpu::AccelerationStructureUpdateMode::Build,
243            max_instances: side_count * side_count,
244        });
245
246        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
247            label: Some("rt_computer"),
248            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
249        });
250
251        let blit_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
252            label: Some("blit"),
253            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("blit.wgsl"))),
254        });
255
256        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
257            label: Some("rt"),
258            layout: None,
259            module: &shader,
260            entry_point: Some("main"),
261            compilation_options: Default::default(),
262            cache: None,
263        });
264
265        let compute_bind_group_layout = compute_pipeline.get_bind_group_layout(0);
266
267        let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
268            label: None,
269            layout: &compute_bind_group_layout,
270            entries: &[
271                wgpu::BindGroupEntry {
272                    binding: 0,
273                    resource: wgpu::BindingResource::TextureView(&rt_view),
274                },
275                wgpu::BindGroupEntry {
276                    binding: 1,
277                    resource: uniform_buf.as_entire_binding(),
278                },
279                wgpu::BindGroupEntry {
280                    binding: 2,
281                    resource: wgpu::BindingResource::AccelerationStructure(&tlas),
282                },
283            ],
284        });
285
286        let blit_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
287            label: Some("blit"),
288            layout: None,
289            vertex: wgpu::VertexState {
290                module: &blit_shader,
291                entry_point: Some("vs_main"),
292                compilation_options: Default::default(),
293                buffers: &[],
294            },
295            fragment: Some(wgpu::FragmentState {
296                module: &blit_shader,
297                entry_point: Some("fs_main"),
298                compilation_options: Default::default(),
299                targets: &[Some(config.format.into())],
300            }),
301            primitive: wgpu::PrimitiveState {
302                topology: wgpu::PrimitiveTopology::TriangleList,
303                ..Default::default()
304            },
305            depth_stencil: None,
306            multisample: wgpu::MultisampleState::default(),
307            multiview_mask: None,
308            cache: None,
309        });
310
311        let blit_bind_group_layout = blit_pipeline.get_bind_group_layout(0);
312
313        let blit_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
314            label: None,
315            layout: &blit_bind_group_layout,
316            entries: &[
317                wgpu::BindGroupEntry {
318                    binding: 0,
319                    resource: wgpu::BindingResource::TextureView(&rt_view),
320                },
321                wgpu::BindGroupEntry {
322                    binding: 1,
323                    resource: wgpu::BindingResource::Sampler(&sampler),
324                },
325            ],
326        });
327
328        let dist = 3.0;
329
330        for x in 0..side_count {
331            for y in 0..side_count {
332                tlas[(x + y * side_count) as usize] = Some(wgpu::TlasInstance::new(
333                    &blas,
334                    affine_to_rows(&Affine3A::from_rotation_translation(
335                        Quat::from_rotation_y(45.9_f32.to_radians()),
336                        Vec3 {
337                            x: x as f32 * dist,
338                            y: y as f32 * dist,
339                            z: -30.0,
340                        },
341                    )),
342                    0,
343                    0xff,
344                ));
345            }
346        }
347
348        let mut encoder =
349            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
350
351        encoder.build_acceleration_structures(
352            iter::once(&wgpu::BlasBuildEntry {
353                blas: &blas,
354                geometry: wgpu::BlasGeometries::TriangleGeometries(vec![
355                    wgpu::BlasTriangleGeometry {
356                        size: &blas_geo_size_desc,
357                        vertex_buffer: &vertex_buf,
358                        first_vertex: 0,
359                        vertex_stride: mem::size_of::<Vertex>() as u64,
360                        index_buffer: Some(&index_buf),
361                        first_index: Some(0),
362                        transform_buffer: None,
363                        transform_buffer_offset: None,
364                    },
365                ]),
366            }),
367            iter::once(&tlas),
368        );
369
370        queue.submit(Some(encoder.finish()));
371
372        Example {
373            rt_target,
374            rt_view,
375            sampler,
376            uniform_buf,
377            vertex_buf,
378            index_buf,
379            tlas,
380            compute_pipeline,
381            compute_bind_group,
382            blit_pipeline,
383            blit_bind_group,
384            animation_timer: utils::AnimationTimer::default(),
385        }
386    }
387
388    fn update(&mut self, _event: winit::event::WindowEvent) {
389        //empty
390    }
391
392    fn resize(
393        &mut self,
394        _config: &wgpu::SurfaceConfiguration,
395        _device: &wgpu::Device,
396        _queue: &wgpu::Queue,
397    ) {
398    }
399
400    fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
401        let anim_time = self.animation_timer.time();
402
403        self.tlas[0].as_mut().unwrap().transform =
404            affine_to_rows(&Affine3A::from_rotation_translation(
405                Quat::from_euler(
406                    glam::EulerRot::XYZ,
407                    anim_time * 0.342,
408                    anim_time * 0.254,
409                    anim_time * 0.832,
410                ),
411                Vec3 {
412                    x: 0.0,
413                    y: 0.0,
414                    z: -6.0,
415                },
416            ));
417
418        let mut encoder =
419            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
420
421        encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas));
422
423        {
424            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
425                label: None,
426                timestamp_writes: None,
427            });
428            cpass.set_pipeline(&self.compute_pipeline);
429            cpass.set_bind_group(0, Some(&self.compute_bind_group), &[]);
430            cpass.dispatch_workgroups(self.rt_target.width() / 8, self.rt_target.height() / 8, 1);
431        }
432
433        {
434            let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
435                label: None,
436                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
437                    view,
438                    depth_slice: None,
439                    resolve_target: None,
440                    ops: wgpu::Operations {
441                        load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
442                        store: StoreOp::Store,
443                    },
444                })],
445                depth_stencil_attachment: None,
446                timestamp_writes: None,
447                occlusion_query_set: None,
448                multiview_mask: None,
449            });
450
451            rpass.set_pipeline(&self.blit_pipeline);
452            rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
453            rpass.draw(0..3, 0..1);
454        }
455
456        queue.submit(Some(encoder.finish()));
457    }
458}
459
460pub fn main() {
461    crate::framework::run::<Example>("ray-cube");
462}
463
// GPU test registration for this example (test builds only): runs `Example` at
// 1024x768 and compares the output against the stored screenshot.
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_cube_compute",
    // Reference image the rendered frame is compared with.
    image_path: "/examples/features/src/ray_cube_compute/screenshot.png",
    width: 1024,
    height: 768,
    // No optional features beyond what `Example::required_features` requests.
    optional_features: wgpu::Features::default(),
    base_test_parameters: wgpu_test::TestParameters::default()
        // https://github.com/gfx-rs/wgpu/issues/9100
        // The test is expected to fail on the Metal backend (see the issue above).
        .expect_fail(wgpu_test::FailureCase::backend(wgpu::Backends::METAL)),
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    _phantom: std::marker::PhantomData::<Example>,
};