wgpu_examples/ray_traced_triangle/
mod.rs

1use glam::{Mat4, Vec3};
2use std::mem;
3use wgpu::util::{BufferInitDescriptor, DeviceExt};
4use wgpu::{include_wgsl, BufferUsages, IndexFormat, SamplerDescriptor};
5use wgpu::{
6    AccelerationStructureFlags, AccelerationStructureUpdateMode, BlasBuildEntry, BlasGeometries,
7    BlasGeometrySizeDescriptors, BlasTriangleGeometry, BlasTriangleGeometrySizeDescriptor,
8    CreateBlasDescriptor, CreateTlasDescriptor, Tlas, TlasInstance,
9};
10
11use crate::utils;
12
/// State kept between frames by the ray-traced triangle example.
struct Example {
    /// Top-level acceleration structure; instance 0's transform is rewritten and
    /// the structure is rebuilt on every call to `render`.
    tlas: Tlas,
    /// Compute pipeline created from `shader.wgsl`.
    compute_pipeline: wgpu::ComputePipeline,
    /// Render pipeline created from `blit.wgsl`; it draws into the surface view.
    blit_pipeline: wgpu::RenderPipeline,
    /// Bind group for the compute pipeline: uniforms, storage texture view and TLAS.
    bind_group: wgpu::BindGroup,
    /// Bind group for the blit pipeline: storage texture view and sampler.
    blit_bind_group: wgpu::BindGroup,
    /// `Rgba8Unorm` texture written by the compute pass and read by the blit pass.
    /// Its width and height also determine the compute dispatch size.
    storage_texture: wgpu::Texture,
    /// Supplies the value used as the Y-rotation angle of instance 0 each frame.
    animation_timer: utils::AnimationTimer,
}
22
/// Camera data uploaded once, at init, to the uniform buffer at binding 0 of the
/// compute bind group.
///
/// `#[repr(C)]` plus the `bytemuck` derives allow the struct to be reinterpreted
/// as raw bytes for the upload, so the field order is part of the buffer layout.
#[repr(C)]
#[derive(bytemuck::Pod, bytemuck::Zeroable, Clone, Copy, Debug)]
struct Uniforms {
    /// Inverse of the view matrix.
    view_inverse: Mat4,
    /// Inverse of the projection matrix.
    proj_inverse: Mat4,
}
29
30impl crate::framework::Example for Example {
31    fn required_features() -> wgpu::Features {
32        wgpu::Features::EXPERIMENTAL_RAY_QUERY
33    }
34
35    fn required_limits() -> wgpu::Limits {
36        wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
37    }
38
39    fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
40        wgpu::DownlevelCapabilities {
41            flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
42            ..Default::default()
43        }
44    }
45
46    fn init(
47        config: &wgpu::SurfaceConfiguration,
48        _adapter: &wgpu::Adapter,
49        device: &wgpu::Device,
50        queue: &wgpu::Queue,
51    ) -> Self {
52        let shader = device.create_shader_module(include_wgsl!("shader.wgsl"));
53
54        let blit_shader = device.create_shader_module(include_wgsl!("blit.wgsl"));
55
56        let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
57            label: Some("bgl for shader.wgsl"),
58            entries: &[
59                wgpu::BindGroupLayoutEntry {
60                    binding: 0,
61                    visibility: wgpu::ShaderStages::COMPUTE,
62                    ty: wgpu::BindingType::Buffer {
63                        ty: wgpu::BufferBindingType::Uniform,
64                        has_dynamic_offset: false,
65                        min_binding_size: None,
66                    },
67                    count: None,
68                },
69                wgpu::BindGroupLayoutEntry {
70                    binding: 1,
71                    visibility: wgpu::ShaderStages::COMPUTE,
72                    ty: wgpu::BindingType::StorageTexture {
73                        access: wgpu::StorageTextureAccess::WriteOnly,
74                        format: wgpu::TextureFormat::Rgba8Unorm,
75                        view_dimension: wgpu::TextureViewDimension::D2,
76                    },
77                    count: None,
78                },
79                wgpu::BindGroupLayoutEntry {
80                    binding: 2,
81                    visibility: wgpu::ShaderStages::COMPUTE,
82                    ty: wgpu::BindingType::AccelerationStructure {
83                        vertex_return: false,
84                    },
85                    count: None,
86                },
87            ],
88        });
89
90        let blit_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
91            label: Some("bgl for blit.wgsl"),
92            entries: &[
93                wgpu::BindGroupLayoutEntry {
94                    binding: 0,
95                    visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
96                    ty: wgpu::BindingType::Texture {
97                        sample_type: wgpu::TextureSampleType::Float { filterable: false },
98                        view_dimension: wgpu::TextureViewDimension::D2,
99                        multisampled: false,
100                    },
101                    count: None,
102                },
103                wgpu::BindGroupLayoutEntry {
104                    binding: 1,
105                    visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
106                    ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::NonFiltering),
107                    count: None,
108                },
109            ],
110        });
111
112        let vertices: [f32; 9] = [1.0, 1.0, 0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 0.0];
113
114        let indices: [u32; 3] = [0, 1, 2];
115
116        let vertex_buffer = device.create_buffer_init(&BufferInitDescriptor {
117            label: Some("vertex buffer"),
118            contents: bytemuck::cast_slice(&vertices),
119            usage: BufferUsages::BLAS_INPUT,
120        });
121
122        let index_buffer = device.create_buffer_init(&BufferInitDescriptor {
123            label: Some("vertex buffer"),
124            contents: bytemuck::cast_slice(&indices),
125            usage: BufferUsages::BLAS_INPUT,
126        });
127
128        let blas_size_desc = BlasTriangleGeometrySizeDescriptor {
129            vertex_format: wgpu::VertexFormat::Float32x3,
130            // 3 coordinates per vertex
131            vertex_count: (vertices.len() / 3) as u32,
132            index_format: Some(IndexFormat::Uint32),
133            index_count: Some(indices.len() as u32),
134            flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
135        };
136
137        let blas = device.create_blas(
138            &CreateBlasDescriptor {
139                label: None,
140                flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
141                update_mode: AccelerationStructureUpdateMode::Build,
142            },
143            BlasGeometrySizeDescriptors::Triangles {
144                descriptors: vec![blas_size_desc.clone()],
145            },
146        );
147
148        let mut tlas = device.create_tlas(&CreateTlasDescriptor {
149            label: None,
150            max_instances: 3,
151            flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
152            update_mode: AccelerationStructureUpdateMode::Build,
153        });
154
155        tlas[0] = Some(TlasInstance::new(
156            &blas,
157            Mat4::from_translation(Vec3 {
158                x: 0.0,
159                y: 0.0,
160                z: 0.0,
161            })
162            .transpose()
163            .to_cols_array()[..12]
164                .try_into()
165                .unwrap(),
166            0,
167            0xff,
168        ));
169
170        tlas[1] = Some(TlasInstance::new(
171            &blas,
172            Mat4::from_translation(Vec3 {
173                x: -1.0,
174                y: -1.0,
175                z: -2.0,
176            })
177            .transpose()
178            .to_cols_array()[..12]
179                .try_into()
180                .unwrap(),
181            0,
182            0xff,
183        ));
184
185        tlas[2] = Some(TlasInstance::new(
186            &blas,
187            Mat4::from_translation(Vec3 {
188                x: 1.0,
189                y: -1.0,
190                z: -2.0,
191            })
192            .transpose()
193            .to_cols_array()[..12]
194                .try_into()
195                .unwrap(),
196            0,
197            0xff,
198        ));
199
200        let uniforms = {
201            let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y);
202            let proj = Mat4::perspective_rh(59.0_f32.to_radians(), 1.0, 0.001, 1000.0);
203
204            Uniforms {
205                view_inverse: view.inverse(),
206                proj_inverse: proj.inverse(),
207            }
208        };
209
210        let uniform_buffer = device.create_buffer_init(&BufferInitDescriptor {
211            label: None,
212            contents: bytemuck::cast_slice(&[uniforms]),
213            usage: BufferUsages::UNIFORM,
214        });
215
216        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
217
218        encoder.build_acceleration_structures(
219            Some(&BlasBuildEntry {
220                blas: &blas,
221                geometry: BlasGeometries::TriangleGeometries(vec![BlasTriangleGeometry {
222                    size: &blas_size_desc,
223                    vertex_buffer: &vertex_buffer,
224                    first_vertex: 0,
225                    vertex_stride: mem::size_of::<[f32; 3]>() as wgpu::BufferAddress,
226                    // in this case since one triangle gets no compression from an index buffer `index_buffer` and `first_index` could be `None`.
227                    index_buffer: Some(&index_buffer),
228                    first_index: Some(0),
229                    transform_buffer: None,
230                    transform_buffer_offset: None,
231                }]),
232            }),
233            Some(&tlas),
234        );
235
236        queue.submit(Some(encoder.finish()));
237
238        let storage_tex = device.create_texture(&wgpu::TextureDescriptor {
239            label: None,
240            size: wgpu::Extent3d {
241                width: config.width,
242                height: config.height,
243                depth_or_array_layers: 1,
244            },
245            mip_level_count: 1,
246            sample_count: 1,
247            dimension: wgpu::TextureDimension::D2,
248            format: wgpu::TextureFormat::Rgba8Unorm,
249            usage: wgpu::TextureUsages::STORAGE_BINDING | wgpu::TextureUsages::TEXTURE_BINDING,
250            view_formats: &[],
251        });
252
253        let sampler = device.create_sampler(&SamplerDescriptor {
254            label: None,
255            address_mode_u: Default::default(),
256            address_mode_v: Default::default(),
257            address_mode_w: Default::default(),
258            mag_filter: wgpu::FilterMode::Nearest,
259            min_filter: wgpu::FilterMode::Nearest,
260            mipmap_filter: wgpu::FilterMode::Nearest,
261            lod_min_clamp: 1.0,
262            lod_max_clamp: 1.0,
263            compare: None,
264            anisotropy_clamp: 1,
265            border_color: None,
266        });
267
268        let compute_pipeline_layout =
269            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
270                label: Some("pipeline layout for shader.wgsl"),
271                bind_group_layouts: &[&bgl],
272                push_constant_ranges: &[],
273            });
274
275        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
276            label: Some("pipeline for shader.wgsl"),
277            layout: Some(&compute_pipeline_layout),
278            module: &shader,
279            entry_point: None,
280            compilation_options: Default::default(),
281            cache: None,
282        });
283
284        let blit_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
285            label: Some("pipeline layout for blit.wgsl"),
286            bind_group_layouts: &[&blit_bgl],
287            push_constant_ranges: &[],
288        });
289
290        let blit_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
291            label: Some("pipeline for blit.wgsl"),
292            layout: Some(&blit_pipeline_layout),
293            vertex: wgpu::VertexState {
294                module: &blit_shader,
295                entry_point: None,
296                compilation_options: Default::default(),
297                buffers: &[],
298            },
299            primitive: Default::default(),
300            depth_stencil: None,
301            multisample: Default::default(),
302            fragment: Some(wgpu::FragmentState {
303                module: &blit_shader,
304                entry_point: None,
305                compilation_options: Default::default(),
306                targets: &[Some(wgpu::ColorTargetState {
307                    format: config.format,
308                    blend: None,
309                    write_mask: Default::default(),
310                })],
311            }),
312            multiview: None,
313            cache: None,
314        });
315
316        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
317            label: Some("bind group for shader.wgsl"),
318            layout: &bgl,
319            entries: &[
320                wgpu::BindGroupEntry {
321                    binding: 0,
322                    resource: uniform_buffer.as_entire_binding(),
323                },
324                wgpu::BindGroupEntry {
325                    binding: 1,
326                    resource: wgpu::BindingResource::TextureView(
327                        &storage_tex.create_view(&wgpu::TextureViewDescriptor::default()),
328                    ),
329                },
330                wgpu::BindGroupEntry {
331                    binding: 2,
332                    resource: wgpu::BindingResource::AccelerationStructure(&tlas),
333                },
334            ],
335        });
336
337        let blit_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
338            label: Some("bind group for blit.wgsl"),
339            layout: &blit_bgl,
340            entries: &[
341                wgpu::BindGroupEntry {
342                    binding: 0,
343                    resource: wgpu::BindingResource::TextureView(
344                        &storage_tex.create_view(&wgpu::TextureViewDescriptor::default()),
345                    ),
346                },
347                wgpu::BindGroupEntry {
348                    binding: 1,
349                    resource: wgpu::BindingResource::Sampler(&sampler),
350                },
351            ],
352        });
353
354        Self {
355            tlas,
356            compute_pipeline,
357            blit_pipeline,
358            bind_group,
359            blit_bind_group,
360            storage_texture: storage_tex,
361            animation_timer: utils::AnimationTimer::default(),
362        }
363    }
364
365    fn resize(
366        &mut self,
367        _config: &wgpu::SurfaceConfiguration,
368        _device: &wgpu::Device,
369        _queue: &wgpu::Queue,
370    ) {
371    }
372
373    fn update(&mut self, _event: winit::event::WindowEvent) {}
374
375    fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
376        self.tlas[0].as_mut().unwrap().transform =
377            Mat4::from_rotation_y(self.animation_timer.time())
378                .transpose()
379                .to_cols_array()[..12]
380                .try_into()
381                .unwrap();
382
383        let mut encoder =
384            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
385
386        encoder.build_acceleration_structures(None, Some(&self.tlas));
387
388        {
389            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
390                label: None,
391                timestamp_writes: None,
392            });
393            cpass.set_pipeline(&self.compute_pipeline);
394            cpass.set_bind_group(0, Some(&self.bind_group), &[]);
395            cpass.dispatch_workgroups(
396                self.storage_texture.width() / 8,
397                self.storage_texture.height() / 8,
398                1,
399            );
400        }
401
402        {
403            let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
404                label: None,
405                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
406                    view,
407                    depth_slice: None,
408                    resolve_target: None,
409                    ops: wgpu::Operations {
410                        load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
411                        store: wgpu::StoreOp::Store,
412                    },
413                })],
414                depth_stencil_attachment: None,
415                timestamp_writes: None,
416                occlusion_query_set: None,
417            });
418
419            rpass.set_pipeline(&self.blit_pipeline);
420            rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
421            rpass.draw(0..3, 0..1);
422        }
423
424        queue.submit(Some(encoder.finish()));
425    }
426}
427
/// Entry point for this example: hands `Example` to the shared framework runner
/// together with the name "ray-traced-triangle".
pub fn main() {
    crate::framework::run::<Example>("ray-traced-triangle");
}
431
// GPU test registration for this example; only compiled in test builds.
// It renders `Example` at 1024x768 and compares the result against the stored
// screenshot using a `Mean(0.02)` comparison (presumably a mean-difference
// threshold; confirm against `wgpu_test::ComparisonType`).
#[cfg(test)]
#[wgpu_test::gpu_test]
static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_traced_triangle",
    image_path: "/examples/features/src/ray_traced_triangle/screenshot.png",
    width: 1024,
    height: 768,
    optional_features: wgpu::Features::default(),
    base_test_parameters: wgpu_test::TestParameters::default(),
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    _phantom: std::marker::PhantomData::<Example>,
};