wgpu_examples/ray_cube_compute/
mod.rs

1use std::{borrow::Cow, future::Future, iter, mem, pin::Pin, task};
2
3use bytemuck::{Pod, Zeroable};
4use glam::{Affine3A, Mat4, Quat, Vec3};
5use wgpu::util::DeviceExt;
6
7use wgpu::StoreOp;
8
9use crate::utils;
10
11// from cube
12#[repr(C)]
13#[derive(Clone, Copy, Pod, Zeroable)]
14struct Vertex {
15    _pos: [f32; 4],
16    _tex_coord: [f32; 2],
17}
18
19fn vertex(pos: [i8; 3], tc: [i8; 2]) -> Vertex {
20    Vertex {
21        _pos: [pos[0] as f32, pos[1] as f32, pos[2] as f32, 1.0],
22        _tex_coord: [tc[0] as f32, tc[1] as f32],
23    }
24}
25
26fn create_vertices() -> (Vec<Vertex>, Vec<u16>) {
27    let vertex_data = [
28        // top (0, 0, 1)
29        vertex([-1, -1, 1], [0, 0]),
30        vertex([1, -1, 1], [1, 0]),
31        vertex([1, 1, 1], [1, 1]),
32        vertex([-1, 1, 1], [0, 1]),
33        // bottom (0, 0, -1)
34        vertex([-1, 1, -1], [1, 0]),
35        vertex([1, 1, -1], [0, 0]),
36        vertex([1, -1, -1], [0, 1]),
37        vertex([-1, -1, -1], [1, 1]),
38        // right (1, 0, 0)
39        vertex([1, -1, -1], [0, 0]),
40        vertex([1, 1, -1], [1, 0]),
41        vertex([1, 1, 1], [1, 1]),
42        vertex([1, -1, 1], [0, 1]),
43        // left (-1, 0, 0)
44        vertex([-1, -1, 1], [1, 0]),
45        vertex([-1, 1, 1], [0, 0]),
46        vertex([-1, 1, -1], [0, 1]),
47        vertex([-1, -1, -1], [1, 1]),
48        // front (0, 1, 0)
49        vertex([1, 1, -1], [1, 0]),
50        vertex([-1, 1, -1], [0, 0]),
51        vertex([-1, 1, 1], [0, 1]),
52        vertex([1, 1, 1], [1, 1]),
53        // back (0, -1, 0)
54        vertex([1, -1, 1], [0, 0]),
55        vertex([-1, -1, 1], [1, 0]),
56        vertex([-1, -1, -1], [1, 1]),
57        vertex([1, -1, -1], [0, 1]),
58    ];
59
60    let index_data: &[u16] = &[
61        0, 1, 2, 2, 3, 0, // top
62        4, 5, 6, 6, 7, 4, // bottom
63        8, 9, 10, 10, 11, 8, // right
64        12, 13, 14, 14, 15, 12, // left
65        16, 17, 18, 18, 19, 16, // front
66        20, 21, 22, 22, 23, 20, // back
67    ];
68
69    (vertex_data.to_vec(), index_data.to_vec())
70}
71
72#[repr(C)]
73#[derive(Clone, Copy, Pod, Zeroable)]
74struct Uniforms {
75    view_inverse: Mat4,
76    proj_inverse: Mat4,
77}
78
79#[inline]
80fn affine_to_rows(mat: &Affine3A) -> [f32; 12] {
81    let row_0 = mat.matrix3.row(0);
82    let row_1 = mat.matrix3.row(1);
83    let row_2 = mat.matrix3.row(2);
84    let translation = mat.translation;
85    [
86        row_0.x,
87        row_0.y,
88        row_0.z,
89        translation.x,
90        row_1.x,
91        row_1.y,
92        row_1.z,
93        translation.y,
94        row_2.x,
95        row_2.y,
96        row_2.z,
97        translation.z,
98    ]
99}
100
101/// A wrapper for `pop_error_scope` futures that panics if an error occurs.
102///
103/// Given a future `inner` of an `Option<E>` for some error type `E`,
104/// wait for the future to be ready, and panic if its value is `Some`.
105///
106/// This can be done simpler with `FutureExt`, but we don't want to add
107/// a dependency just for this small case.
108struct ErrorFuture<F> {
109    inner: F,
110}
111impl<F: Future<Output = Option<wgpu::Error>>> Future for ErrorFuture<F> {
112    type Output = ();
113    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<()> {
114        let inner = unsafe { self.map_unchecked_mut(|me| &mut me.inner) };
115        inner.poll(cx).map(|error| {
116            if let Some(e) = error {
117                panic!("Rendering {e}");
118            }
119        })
120    }
121}
122
123struct Example {
124    rt_target: wgpu::Texture,
125    #[expect(dead_code)]
126    rt_view: wgpu::TextureView,
127    #[expect(dead_code)]
128    sampler: wgpu::Sampler,
129    #[expect(dead_code)]
130    uniform_buf: wgpu::Buffer,
131    #[expect(dead_code)]
132    vertex_buf: wgpu::Buffer,
133    #[expect(dead_code)]
134    index_buf: wgpu::Buffer,
135    tlas: wgpu::Tlas,
136    compute_pipeline: wgpu::ComputePipeline,
137    compute_bind_group: wgpu::BindGroup,
138    blit_pipeline: wgpu::RenderPipeline,
139    blit_bind_group: wgpu::BindGroup,
140    animation_timer: utils::AnimationTimer,
141}
142
143impl crate::framework::Example for Example {
144    fn required_features() -> wgpu::Features {
145        wgpu::Features::TEXTURE_BINDING_ARRAY
146            | wgpu::Features::VERTEX_WRITABLE_STORAGE
147            | wgpu::Features::EXPERIMENTAL_RAY_QUERY
148    }
149
150    fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
151        wgpu::DownlevelCapabilities {
152            flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
153            ..Default::default()
154        }
155    }
156
157    fn required_limits() -> wgpu::Limits {
158        wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
159    }
160
161    fn init(
162        config: &wgpu::SurfaceConfiguration,
163        _adapter: &wgpu::Adapter,
164        device: &wgpu::Device,
165        queue: &wgpu::Queue,
166    ) -> Self {
167        let side_count = 8;
168
169        let rt_target = device.create_texture(&wgpu::TextureDescriptor {
170            label: Some("rt_target"),
171            size: wgpu::Extent3d {
172                width: config.width,
173                height: config.height,
174                depth_or_array_layers: 1,
175            },
176            mip_level_count: 1,
177            sample_count: 1,
178            dimension: wgpu::TextureDimension::D2,
179            format: wgpu::TextureFormat::Rgba8Unorm,
180            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
181            view_formats: &[wgpu::TextureFormat::Rgba8Unorm],
182        });
183
184        let rt_view = rt_target.create_view(&wgpu::TextureViewDescriptor {
185            label: None,
186            format: Some(wgpu::TextureFormat::Rgba8Unorm),
187            dimension: Some(wgpu::TextureViewDimension::D2),
188            usage: None,
189            aspect: wgpu::TextureAspect::All,
190            base_mip_level: 0,
191            mip_level_count: None,
192            base_array_layer: 0,
193            array_layer_count: None,
194        });
195
196        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
197            label: Some("rt_sampler"),
198            address_mode_u: wgpu::AddressMode::ClampToEdge,
199            address_mode_v: wgpu::AddressMode::ClampToEdge,
200            address_mode_w: wgpu::AddressMode::ClampToEdge,
201            mag_filter: wgpu::FilterMode::Linear,
202            min_filter: wgpu::FilterMode::Linear,
203            mipmap_filter: wgpu::FilterMode::Nearest,
204            ..Default::default()
205        });
206
207        let uniforms = {
208            let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y);
209            let proj = Mat4::perspective_rh(
210                59.0_f32.to_radians(),
211                config.width as f32 / config.height as f32,
212                0.001,
213                1000.0,
214            );
215
216            Uniforms {
217                view_inverse: view.inverse(),
218                proj_inverse: proj.inverse(),
219            }
220        };
221
222        let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
223            label: Some("Uniform Buffer"),
224            contents: bytemuck::cast_slice(&[uniforms]),
225            usage: wgpu::BufferUsages::UNIFORM,
226        });
227
228        let (vertex_data, index_data) = create_vertices();
229
230        let vertex_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
231            label: Some("Vertex Buffer"),
232            contents: bytemuck::cast_slice(&vertex_data),
233            usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::BLAS_INPUT,
234        });
235
236        let index_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
237            label: Some("Index Buffer"),
238            contents: bytemuck::cast_slice(&index_data),
239            usage: wgpu::BufferUsages::INDEX | wgpu::BufferUsages::BLAS_INPUT,
240        });
241
242        let blas_geo_size_desc = wgpu::BlasTriangleGeometrySizeDescriptor {
243            vertex_format: wgpu::VertexFormat::Float32x3,
244            vertex_count: vertex_data.len() as u32,
245            index_format: Some(wgpu::IndexFormat::Uint16),
246            index_count: Some(index_data.len() as u32),
247            flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
248        };
249
250        let blas = device.create_blas(
251            &wgpu::CreateBlasDescriptor {
252                label: None,
253                flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
254                update_mode: wgpu::AccelerationStructureUpdateMode::Build,
255            },
256            wgpu::BlasGeometrySizeDescriptors::Triangles {
257                descriptors: vec![blas_geo_size_desc.clone()],
258            },
259        );
260
261        let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
262            label: None,
263            flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
264            update_mode: wgpu::AccelerationStructureUpdateMode::Build,
265            max_instances: side_count * side_count,
266        });
267
268        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
269            label: Some("rt_computer"),
270            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
271        });
272
273        let blit_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
274            label: Some("blit"),
275            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("blit.wgsl"))),
276        });
277
278        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
279            label: Some("rt"),
280            layout: None,
281            module: &shader,
282            entry_point: Some("main"),
283            compilation_options: Default::default(),
284            cache: None,
285        });
286
287        let compute_bind_group_layout = compute_pipeline.get_bind_group_layout(0);
288
289        let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
290            label: None,
291            layout: &compute_bind_group_layout,
292            entries: &[
293                wgpu::BindGroupEntry {
294                    binding: 0,
295                    resource: wgpu::BindingResource::TextureView(&rt_view),
296                },
297                wgpu::BindGroupEntry {
298                    binding: 1,
299                    resource: uniform_buf.as_entire_binding(),
300                },
301                wgpu::BindGroupEntry {
302                    binding: 2,
303                    resource: wgpu::BindingResource::AccelerationStructure(&tlas),
304                },
305            ],
306        });
307
308        let blit_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
309            label: Some("blit"),
310            layout: None,
311            vertex: wgpu::VertexState {
312                module: &blit_shader,
313                entry_point: Some("vs_main"),
314                compilation_options: Default::default(),
315                buffers: &[],
316            },
317            fragment: Some(wgpu::FragmentState {
318                module: &blit_shader,
319                entry_point: Some("fs_main"),
320                compilation_options: Default::default(),
321                targets: &[Some(config.format.into())],
322            }),
323            primitive: wgpu::PrimitiveState {
324                topology: wgpu::PrimitiveTopology::TriangleList,
325                ..Default::default()
326            },
327            depth_stencil: None,
328            multisample: wgpu::MultisampleState::default(),
329            multiview: None,
330            cache: None,
331        });
332
333        let blit_bind_group_layout = blit_pipeline.get_bind_group_layout(0);
334
335        let blit_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
336            label: None,
337            layout: &blit_bind_group_layout,
338            entries: &[
339                wgpu::BindGroupEntry {
340                    binding: 0,
341                    resource: wgpu::BindingResource::TextureView(&rt_view),
342                },
343                wgpu::BindGroupEntry {
344                    binding: 1,
345                    resource: wgpu::BindingResource::Sampler(&sampler),
346                },
347            ],
348        });
349
350        let dist = 3.0;
351
352        for x in 0..side_count {
353            for y in 0..side_count {
354                tlas[(x + y * side_count) as usize] = Some(wgpu::TlasInstance::new(
355                    &blas,
356                    affine_to_rows(&Affine3A::from_rotation_translation(
357                        Quat::from_rotation_y(45.9_f32.to_radians()),
358                        Vec3 {
359                            x: x as f32 * dist,
360                            y: y as f32 * dist,
361                            z: -30.0,
362                        },
363                    )),
364                    0,
365                    0xff,
366                ));
367            }
368        }
369
370        let mut encoder =
371            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
372
373        encoder.build_acceleration_structures(
374            iter::once(&wgpu::BlasBuildEntry {
375                blas: &blas,
376                geometry: wgpu::BlasGeometries::TriangleGeometries(vec![
377                    wgpu::BlasTriangleGeometry {
378                        size: &blas_geo_size_desc,
379                        vertex_buffer: &vertex_buf,
380                        first_vertex: 0,
381                        vertex_stride: mem::size_of::<Vertex>() as u64,
382                        index_buffer: Some(&index_buf),
383                        first_index: Some(0),
384                        transform_buffer: None,
385                        transform_buffer_offset: None,
386                    },
387                ]),
388            }),
389            iter::once(&tlas),
390        );
391
392        queue.submit(Some(encoder.finish()));
393
394        Example {
395            rt_target,
396            rt_view,
397            sampler,
398            uniform_buf,
399            vertex_buf,
400            index_buf,
401            tlas,
402            compute_pipeline,
403            compute_bind_group,
404            blit_pipeline,
405            blit_bind_group,
406            animation_timer: utils::AnimationTimer::default(),
407        }
408    }
409
410    fn update(&mut self, _event: winit::event::WindowEvent) {
411        //empty
412    }
413
414    fn resize(
415        &mut self,
416        _config: &wgpu::SurfaceConfiguration,
417        _device: &wgpu::Device,
418        _queue: &wgpu::Queue,
419    ) {
420    }
421
422    fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
423        device.push_error_scope(wgpu::ErrorFilter::Validation);
424
425        let anim_time = self.animation_timer.time();
426
427        self.tlas[0].as_mut().unwrap().transform =
428            affine_to_rows(&Affine3A::from_rotation_translation(
429                Quat::from_euler(
430                    glam::EulerRot::XYZ,
431                    anim_time * 0.342,
432                    anim_time * 0.254,
433                    anim_time * 0.832,
434                ),
435                Vec3 {
436                    x: 0.0,
437                    y: 0.0,
438                    z: -6.0,
439                },
440            ));
441
442        let mut encoder =
443            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
444
445        encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas));
446
447        {
448            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
449                label: None,
450                timestamp_writes: None,
451            });
452            cpass.set_pipeline(&self.compute_pipeline);
453            cpass.set_bind_group(0, Some(&self.compute_bind_group), &[]);
454            cpass.dispatch_workgroups(self.rt_target.width() / 8, self.rt_target.height() / 8, 1);
455        }
456
457        {
458            let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
459                label: None,
460                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
461                    view,
462                    depth_slice: None,
463                    resolve_target: None,
464                    ops: wgpu::Operations {
465                        load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
466                        store: StoreOp::Store,
467                    },
468                })],
469                depth_stencil_attachment: None,
470                timestamp_writes: None,
471                occlusion_query_set: None,
472            });
473
474            rpass.set_pipeline(&self.blit_pipeline);
475            rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
476            rpass.draw(0..3, 0..1);
477        }
478
479        queue.submit(Some(encoder.finish()));
480    }
481}
482
483pub fn main() {
484    crate::framework::run::<Example>("ray-cube");
485}
486
/// Parameters for the framework's GPU screenshot test of this example.
///
/// The example is rendered at 1024x768 and compared against the stored screenshot at
/// `image_path` using a `Mean(0.02)` comparison.
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_cube_compute",
    image_path: "/examples/features/src/ray_cube_compute/screenshot.png",
    width: 1024,
    height: 768,
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    optional_features: wgpu::Features::default(),
    base_test_parameters: wgpu_test::TestParameters::default(),
    _phantom: std::marker::PhantomData::<Example>,
};