// wgpu_examples/ray_cube_normals/mod.rs

1use std::{borrow::Cow, future::Future, iter, mem, pin::Pin, task};
2
3use bytemuck::{Pod, Zeroable};
4use glam::{Affine3A, Mat4, Quat, Vec3};
5use wgpu::util::DeviceExt;
6
7use crate::utils;
8use wgpu::StoreOp;
9
10// from cube
11#[repr(C)]
12#[derive(Clone, Copy, Pod, Zeroable)]
13struct Vertex {
14    _pos: [f32; 4],
15    _tex_coord: [f32; 2],
16}
17
18fn vertex(pos: [i8; 3], tc: [i8; 2]) -> Vertex {
19    Vertex {
20        _pos: [pos[0] as f32, pos[1] as f32, pos[2] as f32, 1.0],
21        _tex_coord: [tc[0] as f32, tc[1] as f32],
22    }
23}
24
25fn create_vertices() -> (Vec<Vertex>, Vec<u16>) {
26    let vertex_data = [
27        // top (0, 0, 1)
28        vertex([-1, -1, 1], [0, 0]),
29        vertex([1, -1, 1], [1, 0]),
30        vertex([1, 1, 1], [1, 1]),
31        vertex([-1, 1, 1], [0, 1]),
32        // bottom (0, 0, -1)
33        vertex([-1, 1, -1], [1, 0]),
34        vertex([1, 1, -1], [0, 0]),
35        vertex([1, -1, -1], [0, 1]),
36        vertex([-1, -1, -1], [1, 1]),
37        // right (1, 0, 0)
38        vertex([1, -1, -1], [0, 0]),
39        vertex([1, 1, -1], [1, 0]),
40        vertex([1, 1, 1], [1, 1]),
41        vertex([1, -1, 1], [0, 1]),
42        // left (-1, 0, 0)
43        vertex([-1, -1, 1], [1, 0]),
44        vertex([-1, 1, 1], [0, 0]),
45        vertex([-1, 1, -1], [0, 1]),
46        vertex([-1, -1, -1], [1, 1]),
47        // front (0, 1, 0)
48        vertex([1, 1, -1], [1, 0]),
49        vertex([-1, 1, -1], [0, 0]),
50        vertex([-1, 1, 1], [0, 1]),
51        vertex([1, 1, 1], [1, 1]),
52        // back (0, -1, 0)
53        vertex([1, -1, 1], [0, 0]),
54        vertex([-1, -1, 1], [1, 0]),
55        vertex([-1, -1, -1], [1, 1]),
56        vertex([1, -1, -1], [0, 1]),
57    ];
58
59    let index_data: &[u16] = &[
60        0, 1, 2, 2, 3, 0, // top
61        4, 5, 6, 6, 7, 4, // bottom
62        8, 9, 10, 10, 11, 8, // right
63        12, 13, 14, 14, 15, 12, // left
64        16, 17, 18, 18, 19, 16, // front
65        20, 21, 22, 22, 23, 20, // back
66    ];
67
68    (vertex_data.to_vec(), index_data.to_vec())
69}
70
71#[repr(C)]
72#[derive(Clone, Copy, Pod, Zeroable)]
73struct Uniforms {
74    view_inverse: Mat4,
75    proj_inverse: Mat4,
76}
77
78#[inline]
79fn affine_to_rows(mat: &Affine3A) -> [f32; 12] {
80    let row_0 = mat.matrix3.row(0);
81    let row_1 = mat.matrix3.row(1);
82    let row_2 = mat.matrix3.row(2);
83    let translation = mat.translation;
84    [
85        row_0.x,
86        row_0.y,
87        row_0.z,
88        translation.x,
89        row_1.x,
90        row_1.y,
91        row_1.z,
92        translation.y,
93        row_2.x,
94        row_2.y,
95        row_2.z,
96        translation.z,
97    ]
98}
99
100/// A wrapper for `pop_error_scope` futures that panics if an error occurs.
101///
102/// Given a future `inner` of an `Option<E>` for some error type `E`,
103/// wait for the future to be ready, and panic if its value is `Some`.
104///
105/// This can be done simpler with `FutureExt`, but we don't want to add
106/// a dependency just for this small case.
107struct ErrorFuture<F> {
108    inner: F,
109}
110impl<F: Future<Output = Option<wgpu::Error>>> Future for ErrorFuture<F> {
111    type Output = ();
112    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<()> {
113        let inner = unsafe { self.map_unchecked_mut(|me| &mut me.inner) };
114        inner.poll(cx).map(|error| {
115            if let Some(e) = error {
116                panic!("Rendering {e}");
117            }
118        })
119    }
120}
121
122struct Example {
123    rt_target: wgpu::Texture,
124    tlas: wgpu::Tlas,
125    compute_pipeline: wgpu::ComputePipeline,
126    compute_bind_group: wgpu::BindGroup,
127    blit_pipeline: wgpu::RenderPipeline,
128    blit_bind_group: wgpu::BindGroup,
129    animation_timer: utils::AnimationTimer,
130}
131
132impl crate::framework::Example for Example {
133    // Don't want srgb, so normals show up better.
134    const SRGB: bool = false;
135    fn required_features() -> wgpu::Features {
136        wgpu::Features::EXPERIMENTAL_RAY_QUERY | wgpu::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN
137    }
138
139    fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
140        wgpu::DownlevelCapabilities {
141            flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
142            ..Default::default()
143        }
144    }
145
146    fn required_limits() -> wgpu::Limits {
147        wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
148    }
149
150    fn init(
151        config: &wgpu::SurfaceConfiguration,
152        _adapter: &wgpu::Adapter,
153        device: &wgpu::Device,
154        queue: &wgpu::Queue,
155    ) -> Self {
156        let side_count = 8;
157
158        let rt_target = device.create_texture(&wgpu::TextureDescriptor {
159            label: Some("rt_target"),
160            size: wgpu::Extent3d {
161                width: config.width,
162                height: config.height,
163                depth_or_array_layers: 1,
164            },
165            mip_level_count: 1,
166            sample_count: 1,
167            dimension: wgpu::TextureDimension::D2,
168            format: wgpu::TextureFormat::Rgba8Unorm,
169            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
170            view_formats: &[wgpu::TextureFormat::Rgba8Unorm],
171        });
172
173        let rt_view = rt_target.create_view(&wgpu::TextureViewDescriptor {
174            label: None,
175            format: Some(wgpu::TextureFormat::Rgba8Unorm),
176            dimension: Some(wgpu::TextureViewDimension::D2),
177            usage: None,
178            aspect: wgpu::TextureAspect::All,
179            base_mip_level: 0,
180            mip_level_count: None,
181            base_array_layer: 0,
182            array_layer_count: None,
183        });
184
185        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
186            label: Some("rt_sampler"),
187            address_mode_u: wgpu::AddressMode::ClampToEdge,
188            address_mode_v: wgpu::AddressMode::ClampToEdge,
189            address_mode_w: wgpu::AddressMode::ClampToEdge,
190            mag_filter: wgpu::FilterMode::Linear,
191            min_filter: wgpu::FilterMode::Linear,
192            mipmap_filter: wgpu::FilterMode::Nearest,
193            ..Default::default()
194        });
195
196        let uniforms = {
197            let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y);
198            let proj = Mat4::perspective_rh(
199                59.0_f32.to_radians(),
200                config.width as f32 / config.height as f32,
201                0.001,
202                1000.0,
203            );
204
205            Uniforms {
206                view_inverse: view.inverse(),
207                proj_inverse: proj.inverse(),
208            }
209        };
210
211        let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
212            label: Some("Uniform Buffer"),
213            contents: bytemuck::cast_slice(&[uniforms]),
214            usage: wgpu::BufferUsages::UNIFORM,
215        });
216
217        let (vertex_data, index_data) = create_vertices();
218
219        let vertex_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
220            label: Some("Vertex Buffer"),
221            contents: bytemuck::cast_slice(&vertex_data),
222            usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::BLAS_INPUT,
223        });
224
225        let index_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
226            label: Some("Index Buffer"),
227            contents: bytemuck::cast_slice(&index_data),
228            usage: wgpu::BufferUsages::INDEX | wgpu::BufferUsages::BLAS_INPUT,
229        });
230
231        let blas_geo_size_desc = wgpu::BlasTriangleGeometrySizeDescriptor {
232            vertex_format: wgpu::VertexFormat::Float32x3,
233            vertex_count: vertex_data.len() as u32,
234            index_format: Some(wgpu::IndexFormat::Uint16),
235            index_count: Some(index_data.len() as u32),
236            flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
237        };
238
239        let blas = device.create_blas(
240            &wgpu::CreateBlasDescriptor {
241                label: None,
242                flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE
243                    | wgpu::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
244                update_mode: wgpu::AccelerationStructureUpdateMode::Build,
245            },
246            wgpu::BlasGeometrySizeDescriptors::Triangles {
247                descriptors: vec![blas_geo_size_desc.clone()],
248            },
249        );
250
251        let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
252            label: None,
253            flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE
254                | wgpu::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
255            update_mode: wgpu::AccelerationStructureUpdateMode::Build,
256            max_instances: side_count * side_count,
257        });
258
259        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
260            label: Some("rt_computer"),
261            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
262        });
263
264        let blit_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
265            label: Some("blit"),
266            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("blit.wgsl"))),
267        });
268
269        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
270            label: Some("rt"),
271            layout: None,
272            module: &shader,
273            entry_point: None,
274            compilation_options: Default::default(),
275            cache: None,
276        });
277
278        let compute_bind_group_layout = compute_pipeline.get_bind_group_layout(0);
279
280        let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
281            label: None,
282            layout: &compute_bind_group_layout,
283            entries: &[
284                wgpu::BindGroupEntry {
285                    binding: 0,
286                    resource: wgpu::BindingResource::TextureView(&rt_view),
287                },
288                wgpu::BindGroupEntry {
289                    binding: 1,
290                    resource: uniform_buf.as_entire_binding(),
291                },
292                wgpu::BindGroupEntry {
293                    binding: 2,
294                    resource: wgpu::BindingResource::AccelerationStructure(&tlas),
295                },
296            ],
297        });
298
299        let blit_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
300            label: Some("blit"),
301            layout: None,
302            vertex: wgpu::VertexState {
303                module: &blit_shader,
304                entry_point: Some("vs_main"),
305                compilation_options: Default::default(),
306                buffers: &[],
307            },
308            fragment: Some(wgpu::FragmentState {
309                module: &blit_shader,
310                entry_point: Some("fs_main"),
311                compilation_options: Default::default(),
312                targets: &[Some(config.format.into())],
313            }),
314            primitive: wgpu::PrimitiveState {
315                topology: wgpu::PrimitiveTopology::TriangleList,
316                ..Default::default()
317            },
318            depth_stencil: None,
319            multisample: wgpu::MultisampleState::default(),
320            multiview: None,
321            cache: None,
322        });
323
324        let blit_bind_group_layout = blit_pipeline.get_bind_group_layout(0);
325
326        let blit_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
327            label: None,
328            layout: &blit_bind_group_layout,
329            entries: &[
330                wgpu::BindGroupEntry {
331                    binding: 0,
332                    resource: wgpu::BindingResource::TextureView(&rt_view),
333                },
334                wgpu::BindGroupEntry {
335                    binding: 1,
336                    resource: wgpu::BindingResource::Sampler(&sampler),
337                },
338            ],
339        });
340
341        let dist = 3.0;
342
343        for x in 0..side_count {
344            for y in 0..side_count {
345                tlas[(x + y * side_count) as usize] = Some(wgpu::TlasInstance::new(
346                    &blas,
347                    affine_to_rows(&Affine3A::from_rotation_translation(
348                        Quat::from_rotation_y(45.9_f32.to_radians()),
349                        Vec3 {
350                            x: x as f32 * dist,
351                            y: y as f32 * dist,
352                            z: -30.0,
353                        },
354                    )),
355                    0,
356                    0xff,
357                ));
358            }
359        }
360
361        let mut encoder =
362            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
363
364        encoder.build_acceleration_structures(
365            iter::once(&wgpu::BlasBuildEntry {
366                blas: &blas,
367                geometry: wgpu::BlasGeometries::TriangleGeometries(vec![
368                    wgpu::BlasTriangleGeometry {
369                        size: &blas_geo_size_desc,
370                        vertex_buffer: &vertex_buf,
371                        first_vertex: 0,
372                        vertex_stride: mem::size_of::<Vertex>() as u64,
373                        index_buffer: Some(&index_buf),
374                        first_index: Some(0),
375                        transform_buffer: None,
376                        transform_buffer_offset: None,
377                    },
378                ]),
379            }),
380            iter::once(&tlas),
381        );
382
383        queue.submit(Some(encoder.finish()));
384
385        Example {
386            rt_target,
387            tlas,
388            compute_pipeline,
389            compute_bind_group,
390            blit_pipeline,
391            blit_bind_group,
392            animation_timer: utils::AnimationTimer::default(),
393        }
394    }
395
396    fn update(&mut self, _event: winit::event::WindowEvent) {
397        //empty
398    }
399
400    fn resize(
401        &mut self,
402        _config: &wgpu::SurfaceConfiguration,
403        _device: &wgpu::Device,
404        _queue: &wgpu::Queue,
405    ) {
406    }
407
408    fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
409        device.push_error_scope(wgpu::ErrorFilter::Validation);
410
411        let anim_time = self.animation_timer.time();
412
413        self.tlas[0].as_mut().unwrap().transform =
414            affine_to_rows(&Affine3A::from_rotation_translation(
415                Quat::from_euler(
416                    glam::EulerRot::XYZ,
417                    anim_time * 0.342,
418                    anim_time * 0.254,
419                    anim_time * 0.832,
420                ),
421                Vec3 {
422                    x: 0.0,
423                    y: 0.0,
424                    z: -6.0,
425                },
426            ));
427
428        let mut encoder =
429            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
430
431        encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas));
432
433        {
434            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
435                label: None,
436                timestamp_writes: None,
437            });
438            cpass.set_pipeline(&self.compute_pipeline);
439            cpass.set_bind_group(0, Some(&self.compute_bind_group), &[]);
440            cpass.dispatch_workgroups(self.rt_target.width() / 8, self.rt_target.height() / 8, 1);
441        }
442
443        {
444            let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
445                label: None,
446                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
447                    view,
448                    depth_slice: None,
449                    resolve_target: None,
450                    ops: wgpu::Operations {
451                        load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
452                        store: StoreOp::Store,
453                    },
454                })],
455                depth_stencil_attachment: None,
456                timestamp_writes: None,
457                occlusion_query_set: None,
458            });
459
460            rpass.set_pipeline(&self.blit_pipeline);
461            rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
462            rpass.draw(0..3, 0..1);
463        }
464
465        queue.submit(Some(encoder.finish()));
466    }
467}
468
469pub fn main() {
470    crate::framework::run::<Example>("ray-cube");
471}
472
/// Screenshot comparison test: renders a 1024x768 frame and compares it with the stored
/// reference image by mean difference.
///
/// Expected to fail with "Image data mismatch" on Vulkan AMD adapters.
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_cube_normals",
    width: 1024,
    height: 768,
    image_path: "/examples/features/src/ray_cube_normals/screenshot.png",
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    optional_features: wgpu::Features::default(),
    base_test_parameters: wgpu_test::TestParameters::default().expect_fail(
        wgpu_test::FailureCase::backend_adapter(wgpu::Backends::VULKAN, "AMD")
            .panic("Image data mismatch"),
    ),
    _phantom: std::marker::PhantomData::<Example>,
};
485    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
486    _phantom: std::marker::PhantomData::<Example>,
487};