// wgpu_examples/ray_shadows/mod.rs

1use std::{borrow::Cow, future::Future, iter, mem, pin::Pin, task};
2
3use bytemuck::{Pod, Zeroable};
4use glam::{Mat4, Vec3};
5use wgpu::util::DeviceExt;
6use wgpu::{vertex_attr_array, IndexFormat, VertexBufferLayout};
7
8use crate::utils;
9
// Vertex layout borrowed from the cube example.
/// Interleaved per-vertex data shared by the raster vertex buffer and the BLAS build input.
///
/// `#[repr(C)]` fixes the field order: the pipeline reads it as
/// `vertex_attr_array![0 => Float32x3, 1 => Float32x3]`, and the BLAS build reads the
/// first field as a `Float32x3` position, both with a stride of `size_of::<Vertex>()`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Vertex {
    // Position (shader location 0). Underscore prefix: only the GPU reads these fields.
    _pos: [f32; 3],
    // Normal (shader location 1).
    _normal: [f32; 3],
}
17
18fn vertex(pos: [f32; 3], normal: [f32; 3]) -> Vertex {
19    Vertex {
20        _pos: pos,
21        _normal: normal,
22    }
23}
24
25fn create_vertices() -> (Vec<Vertex>, Vec<u16>) {
26    let vertex_data = [
27        // base
28        vertex([-1.0, 0.0, -1.0], [0.0, 1.0, 0.0]),
29        vertex([-1.0, 0.0, 1.0], [0.0, 1.0, 0.0]),
30        vertex([1.0, 0.0, -1.0], [0.0, 1.0, 0.0]),
31        vertex([1.0, 0.0, 1.0], [0.0, 1.0, 0.0]),
32        //shadow caster
33        vertex([-(1.0 / 3.0), 0.0, 1.0], [0.0, 0.0, 1.0]),
34        vertex([-(1.0 / 3.0), 2.0 / 3.0, 1.0], [0.0, 0.0, 1.0]),
35        vertex([1.0 / 3.0, 0.0, 1.0], [0.0, 0.0, 1.0]),
36        vertex([1.0 / 3.0, 2.0 / 3.0, 1.0], [0.0, 0.0, 1.0]),
37    ];
38
39    let index_data: &[u16] = &[
40        0, 1, 2, 2, 3, 1, //base
41        4, 5, 6, 6, 7, 5,
42    ];
43
44    (vertex_data.to_vec(), index_data.to_vec())
45}
46
/// Camera matrices uploaded verbatim (via `bytemuck`) into the uniform buffer at binding 0.
///
/// `#[repr(C)]` keeps the field order stable for the byte copy.
/// NOTE(review): the field order must match the uniform struct in `shader.wgsl`, which is
/// not visible here — confirm there.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Uniforms {
    /// Inverse of the view matrix.
    view_inverse: Mat4,
    /// Inverse of the projection matrix.
    proj_inverse: Mat4,
    /// Combined `proj * view` matrix (despite the name, not a per-vertex value).
    vertex: Mat4,
}
54
/// A wrapper for `pop_error_scope` futures that panics if an error occurs.
///
/// Given a future `inner` of an `Option<E>` for some error type `E`,
/// wait for the future to be ready, and panic if its value is `Some`.
/// The `Future` impl for this type fixes `E` to `wgpu::Error`.
///
/// This could be done more simply with `FutureExt` (e.g. from the `futures` crate),
/// but we don't want to add a dependency just for this small case.
///
/// NOTE(review): nothing in this file constructs an `ErrorFuture`.
struct ErrorFuture<F> {
    /// The wrapped future; it is polled through a pinned projection and never moved out.
    inner: F,
}
65impl<F: Future<Output = Option<wgpu::Error>>> Future for ErrorFuture<F> {
66    type Output = ();
67    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<()> {
68        let inner = unsafe { self.map_unchecked_mut(|me| &mut me.inner) };
69        inner.poll(cx).map(|error| {
70            if let Some(e) = error {
71                panic!("Rendering {e}");
72            }
73        })
74    }
75}
76
/// State for the ray-traced shadows example.
///
/// Fields are dropped in declaration order; keep that in mind before reordering them.
struct Example {
    /// CPU-side copy of the camera uniforms; rebuilt in `resize`.
    uniforms: Uniforms,
    /// GPU buffer holding `uniforms` (bind group binding 0).
    uniform_buf: wgpu::Buffer,
    /// Vertex data; also used as input to the BLAS build.
    vertex_buf: wgpu::Buffer,
    /// `u16` indices; also used as input to the BLAS build.
    index_buf: wgpu::Buffer,
    /// The single render pipeline used to draw the scene.
    pipeline: wgpu::RenderPipeline,
    /// Binds the uniform buffer (binding 0) and the TLAS (binding 1).
    bind_group: wgpu::BindGroup,
    /// Time source for the light animation in `render`.
    animation_timer: utils::AnimationTimer,
}
86
/// Camera eye position: the first (`eye`) argument of `Mat4::look_at_rh` in `create_matrix`.
/// Despite the name, this is where the camera is; the point it looks at is the origin.
const CAM_LOOK_AT: Vec3 = Vec3::new(0.0, 1.0, -1.5);
88
89fn create_matrix(config: &wgpu::SurfaceConfiguration) -> Uniforms {
90    let view = Mat4::look_at_rh(CAM_LOOK_AT, Vec3::ZERO, Vec3::Y);
91    let proj = Mat4::perspective_rh(
92        59.0_f32.to_radians(),
93        config.width as f32 / config.height as f32,
94        0.1,
95        1000.0,
96    );
97
98    Uniforms {
99        view_inverse: view.inverse(),
100        proj_inverse: proj.inverse(),
101        vertex: (proj * view),
102    }
103}
104
105impl crate::framework::Example for Example {
106    fn required_features() -> wgpu::Features {
107        wgpu::Features::EXPERIMENTAL_RAY_QUERY | wgpu::Features::PUSH_CONSTANTS
108    }
109
110    fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
111        wgpu::DownlevelCapabilities {
112            flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
113            ..Default::default()
114        }
115    }
116
117    fn required_limits() -> wgpu::Limits {
118        wgpu::Limits {
119            max_push_constant_size: 12,
120            ..wgpu::Limits::default()
121        }
122        .using_minimum_supported_acceleration_structure_values()
123    }
124
125    fn init(
126        config: &wgpu::SurfaceConfiguration,
127        _adapter: &wgpu::Adapter,
128        device: &wgpu::Device,
129        queue: &wgpu::Queue,
130    ) -> Self {
131        let uniforms = create_matrix(config);
132
133        let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
134            label: Some("Uniform Buffer"),
135            contents: bytemuck::cast_slice(&[uniforms]),
136            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
137        });
138
139        let (vertex_data, index_data) = create_vertices();
140
141        let vertex_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
142            label: Some("Vertex Buffer"),
143            contents: bytemuck::cast_slice(&vertex_data),
144            usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::BLAS_INPUT,
145        });
146
147        let index_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
148            label: Some("Index Buffer"),
149            contents: bytemuck::cast_slice(&index_data),
150            usage: wgpu::BufferUsages::INDEX | wgpu::BufferUsages::BLAS_INPUT,
151        });
152
153        let blas_geo_size_desc = wgpu::BlasTriangleGeometrySizeDescriptor {
154            vertex_format: wgpu::VertexFormat::Float32x3,
155            vertex_count: vertex_data.len() as u32,
156            index_format: Some(wgpu::IndexFormat::Uint16),
157            index_count: Some(index_data.len() as u32),
158            flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
159        };
160
161        let blas = device.create_blas(
162            &wgpu::CreateBlasDescriptor {
163                label: None,
164                flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
165                update_mode: wgpu::AccelerationStructureUpdateMode::Build,
166            },
167            wgpu::BlasGeometrySizeDescriptors::Triangles {
168                descriptors: vec![blas_geo_size_desc.clone()],
169            },
170        );
171
172        let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
173            label: None,
174            flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
175            update_mode: wgpu::AccelerationStructureUpdateMode::Build,
176            max_instances: 1,
177        });
178
179        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
180            label: None,
181            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
182        });
183
184        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
185            label: None,
186            entries: &[
187                wgpu::BindGroupLayoutEntry {
188                    binding: 0,
189                    visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
190                    ty: wgpu::BindingType::Buffer {
191                        ty: wgpu::BufferBindingType::Uniform,
192                        has_dynamic_offset: false,
193                        min_binding_size: None,
194                    },
195                    count: None,
196                },
197                wgpu::BindGroupLayoutEntry {
198                    binding: 1,
199                    visibility: wgpu::ShaderStages::FRAGMENT,
200                    ty: wgpu::BindingType::AccelerationStructure {
201                        vertex_return: false,
202                    },
203                    count: None,
204                },
205            ],
206        });
207
208        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
209            label: None,
210            bind_group_layouts: &[&bind_group_layout],
211            push_constant_ranges: &[wgpu::PushConstantRange {
212                stages: wgpu::ShaderStages::FRAGMENT,
213                range: 0..12,
214            }],
215        });
216
217        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
218            label: None,
219            layout: Some(&pipeline_layout),
220            vertex: wgpu::VertexState {
221                module: &shader,
222                entry_point: Some("vs_main"),
223                compilation_options: Default::default(),
224                buffers: &[VertexBufferLayout {
225                    array_stride: mem::size_of::<Vertex>() as wgpu::BufferAddress,
226                    step_mode: Default::default(),
227                    attributes: &vertex_attr_array![0 => Float32x3, 1 => Float32x3],
228                }],
229            },
230            fragment: Some(wgpu::FragmentState {
231                module: &shader,
232                entry_point: Some("fs_main"),
233                compilation_options: Default::default(),
234                targets: &[Some(config.format.into())],
235            }),
236            primitive: wgpu::PrimitiveState {
237                topology: wgpu::PrimitiveTopology::TriangleList,
238                ..Default::default()
239            },
240            depth_stencil: None,
241            multisample: wgpu::MultisampleState::default(),
242            multiview: None,
243            cache: None,
244        });
245
246        tlas[0] = Some(wgpu::TlasInstance::new(
247            &blas,
248            [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
249            0,
250            0xFF,
251        ));
252
253        let mut encoder =
254            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
255
256        encoder.build_acceleration_structures(
257            iter::once(&wgpu::BlasBuildEntry {
258                blas: &blas,
259                geometry: wgpu::BlasGeometries::TriangleGeometries(vec![
260                    wgpu::BlasTriangleGeometry {
261                        size: &blas_geo_size_desc,
262                        vertex_buffer: &vertex_buf,
263                        first_vertex: 0,
264                        vertex_stride: mem::size_of::<Vertex>() as u64,
265                        index_buffer: Some(&index_buf),
266                        first_index: Some(0),
267                        transform_buffer: None,
268                        transform_buffer_offset: None,
269                    },
270                ]),
271            }),
272            iter::once(&tlas),
273        );
274
275        queue.submit(Some(encoder.finish()));
276
277        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
278            label: None,
279            layout: &bind_group_layout,
280            entries: &[
281                wgpu::BindGroupEntry {
282                    binding: 0,
283                    resource: uniform_buf.as_entire_binding(),
284                },
285                wgpu::BindGroupEntry {
286                    binding: 1,
287                    resource: tlas.as_binding(),
288                },
289            ],
290        });
291
292        let animation_timer = utils::AnimationTimer::default();
293
294        Example {
295            uniforms,
296            uniform_buf,
297            vertex_buf,
298            index_buf,
299            pipeline,
300            bind_group,
301            animation_timer,
302        }
303    }
304
305    fn update(&mut self, _event: winit::event::WindowEvent) {}
306
307    fn resize(
308        &mut self,
309        config: &wgpu::SurfaceConfiguration,
310        _device: &wgpu::Device,
311        queue: &wgpu::Queue,
312    ) {
313        self.uniforms = create_matrix(config);
314
315        queue.write_buffer(&self.uniform_buf, 0, bytemuck::cast_slice(&[self.uniforms]));
316        queue.submit(None);
317    }
318
319    fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
320        //device.push_error_scope(wgpu::ErrorFilter::Validation);
321        const LIGHT_DISTANCE: f32 = 5.0;
322        const TIME_SCALE: f32 = -0.2;
323        const INITIAL_TIME: f32 = 1.0;
324        let time = self.animation_timer.time();
325        let cos = (time * TIME_SCALE + INITIAL_TIME).cos() * LIGHT_DISTANCE;
326        let sin = (time * TIME_SCALE + INITIAL_TIME).sin() * LIGHT_DISTANCE;
327
328        let mut encoder =
329            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
330
331        {
332            let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
333                label: None,
334                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
335                    view,
336                    depth_slice: None,
337                    resolve_target: None,
338                    ops: wgpu::Operations {
339                        load: wgpu::LoadOp::Clear(wgpu::Color {
340                            r: 0.1,
341                            g: 0.1,
342                            b: 0.1,
343                            a: 1.0,
344                        }),
345                        store: wgpu::StoreOp::Store,
346                    },
347                })],
348                depth_stencil_attachment: None,
349                timestamp_writes: None,
350                occlusion_query_set: None,
351            });
352
353            rpass.set_pipeline(&self.pipeline);
354            rpass.set_bind_group(0, Some(&self.bind_group), &[]);
355            rpass.set_push_constants(wgpu::ShaderStages::FRAGMENT, 0, &0.0_f32.to_ne_bytes());
356            rpass.set_push_constants(wgpu::ShaderStages::FRAGMENT, 4, &cos.to_ne_bytes());
357            rpass.set_push_constants(wgpu::ShaderStages::FRAGMENT, 8, &sin.to_ne_bytes());
358            rpass.set_vertex_buffer(0, self.vertex_buf.slice(..));
359            rpass.set_index_buffer(self.index_buf.slice(..), IndexFormat::Uint16);
360            rpass.draw_indexed(0..12, 0, 0..1);
361        }
362        queue.submit(Some(encoder.finish()));
363        device.poll(wgpu::PollType::Wait).unwrap();
364    }
365}
366
367pub fn main() {
368    crate::framework::run::<Example>("ray-shadows");
369}
370
// Image-comparison GPU test for this example. It renders at 1024x768 and compares the
// output against the reference screenshot at `image_path`, using a `Mean(0.02)` comparison
// (presumably a mean per-pixel error threshold — confirm in `wgpu_test`).
// No optional features are requested beyond what the example itself requires.
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_shadows",
    image_path: "/examples/features/src/ray_shadows/screenshot.png",
    width: 1024,
    height: 768,
    optional_features: wgpu::Features::default(),
    base_test_parameters: wgpu_test::TestParameters::default(),
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    _phantom: std::marker::PhantomData::<Example>,
};