wgpu_examples/ray_scene/
mod.rs

1use crate::utils;
2use bytemuck::{Pod, Zeroable};
3use glam::{Mat4, Quat, Vec3};
4use std::f32::consts::PI;
5use std::ops::IndexMut;
6use std::{borrow::Cow, future::Future, iter, mem, ops::Range, pin::Pin, task};
7use wgpu::util::DeviceExt;
8
9// from cube
/// A single mesh vertex as laid out in the GPU vertex buffer.
///
/// `#[repr(C)]` plus the explicit padding fields give a fixed 48-byte layout:
/// `pos` at byte 0, `normal` at byte 16 and `uv` at byte 32. The size of this
/// struct is also used as the vertex stride when building acceleration
/// structures.
/// NOTE(review): the padding presumably mirrors the layout of the matching
/// struct in `shader.wgsl` — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable, Default)]
struct Vertex {
    /// Position, copied from the OBJ file's position list.
    pos: [f32; 3],
    /// Padding after `pos`; left at its default of zero.
    _p0: [u32; 1],
    /// Normal, copied from the OBJ file's normal list.
    normal: [f32; 3],
    /// Padding after `normal`; left at its default of zero.
    _p1: [u32; 1],
    /// Texture coordinate, or `[0.0, 0.0]` when the OBJ corner has none.
    uv: [f32; 2],
    /// Trailing padding; left at its default of zero.
    _p2: [u32; 2],
}
20
/// Camera data stored in the uniform buffer that is bound at binding 0.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Uniforms {
    /// Inverse of the view matrix.
    view_inverse: Mat4,
    /// Inverse of the projection matrix; recomputed on resize.
    proj_inverse: Mat4,
}
27
28/// A wrapper for `pop_error_scope` futures that panics if an error occurs.
29///
30/// Given a future `inner` of an `Option<E>` for some error type `E`,
31/// wait for the future to be ready, and panic if its value is `Some`.
32///
33/// This can be done simpler with `FutureExt`, but we don't want to add
34/// a dependency just for this small case.
35struct ErrorFuture<F> {
36    inner: F,
37}
38impl<F: Future<Output = Option<wgpu::Error>>> Future for ErrorFuture<F> {
39    type Output = ();
40    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<()> {
41        let inner = unsafe { self.map_unchecked_mut(|me| &mut me.inner) };
42        inner.poll(cx).map(|error| {
43            if let Some(e) = error {
44                panic!("Rendering {e}");
45            }
46        })
47    }
48}
49
/// CPU-side scene data accumulated by `load_model` before it is uploaded.
#[derive(Debug, Clone, Default)]
struct RawSceneComponents {
    /// Vertices of all loaded models, back to back.
    vertices: Vec<Vertex>,
    /// Triangle-list indices of all loaded models; each value is relative to
    /// the first vertex of the model it belongs to.
    indices: Vec<u32>,
    /// One entry per OBJ group: its range in `indices` and its material.
    geometries: Vec<(Range<usize>, Material)>, // index range, material
    /// One entry per loaded model: its range in `vertices` and in `geometries`.
    instances: Vec<(Range<usize>, Range<usize>)>, // vertex range, geometry range
}
57
/// GPU-side scene data produced by `upload_scene_components`.
struct SceneComponents {
    /// Buffer holding every `Vertex` of the scene.
    vertices: wgpu::Buffer,
    /// Buffer holding every `u32` index of the scene.
    indices: wgpu::Buffer,
    /// Storage buffer of `GeometryEntry` values, one per geometry.
    geometries: wgpu::Buffer,
    /// Storage buffer of `InstanceEntry` values, one per loaded model.
    instances: wgpu::Buffer,
    /// One bottom-level acceleration structure per loaded model, in load order.
    bottom_level_acceleration_structures: Vec<wgpu::Blas>,
}
65
/// Per-model record stored in the `instances` storage buffer (16 bytes).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct InstanceEntry {
    /// Index of the model's first vertex in the vertex buffer.
    first_vertex: u32,
    /// Index of the model's first entry in the geometry buffer.
    first_geometry: u32,
    /// End of the model's geometry range (exclusive: it is the range's `end`).
    last_geometry: u32,
    /// Padding word; `upload_scene_components` writes `1` here.
    _pad: u32,
}
74
/// Per-geometry record stored in the `geometries` storage buffer.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Default)]
struct GeometryEntry {
    /// Position of the geometry's first index in the index buffer.
    first_index: u32,
    /// Padding so that `material` starts at byte 16; left at zero.
    _p0: [u32; 3],
    /// Material applied to every triangle of the geometry.
    material: Material,
}
82
/// Surface parameters of a geometry, filled from the OBJ file's MTL data
/// (32 bytes). Every field stays at its zero default when the MTL file does
/// not provide the corresponding value.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Default, Debug)]
struct Material {
    /// Taken from the MTL `Ns` value.
    roughness_exponent: f32,
    /// Taken from the first component of the MTL `Ka` value.
    metalness: f32,
    /// Taken from the first component of the MTL `Ks` value.
    specularity: f32,
    /// Padding so that `albedo` starts at byte 16; left at zero.
    _p0: [u32; 1],
    /// Taken from the MTL `Kd` value.
    albedo: [f32; 3],
    /// Trailing padding; left at zero.
    _p1: [u32; 1],
}
93
94fn load_model(scene: &mut RawSceneComponents, path: &str) {
95    let path = env!("CARGO_MANIFEST_DIR").to_string() + "/src" + path;
96    println!("{path}");
97    let mut object = obj::Obj::load(path).unwrap();
98    object.load_mtls().unwrap();
99
100    let data = object.data;
101
102    let start_vertex_index = scene.vertices.len();
103    let start_geometry_index = scene.geometries.len();
104
105    let mut mapping = std::collections::HashMap::<(usize, Option<usize>, usize), usize>::new();
106
107    let mut next_index = 0;
108
109    for object in data.objects {
110        for group in object.groups {
111            let start_index_index = scene.indices.len();
112            for poly in group.polys {
113                for end_index in 2..poly.0.len() {
114                    for &index in &[0, end_index - 1, end_index] {
115                        let obj::IndexTuple(position_id, texture_id, normal_id) = poly.0[index];
116                        let uv = texture_id
117                            .map(|texture_id| data.texture[texture_id])
118                            .unwrap_or_default();
119                        let normal_id = normal_id.expect("normals required");
120
121                        let index = *mapping
122                            .entry((position_id, texture_id, normal_id))
123                            .or_insert(next_index);
124                        if index == next_index {
125                            next_index += 1;
126
127                            scene.vertices.push(Vertex {
128                                pos: data.position[position_id],
129                                uv,
130                                normal: data.normal[normal_id],
131                                ..Default::default()
132                            })
133                        }
134
135                        scene.indices.push(index as u32);
136                    }
137                }
138            }
139
140            let mut material: Material = Default::default();
141
142            if let Some(obj::ObjMaterial::Mtl(mat)) = group.material {
143                if let Some(kd) = mat.kd {
144                    material.albedo = kd;
145                }
146                if let Some(ns) = mat.ns {
147                    material.roughness_exponent = ns;
148                }
149                if let Some(ka) = mat.ka {
150                    material.metalness = ka[0];
151                }
152                if let Some(ks) = mat.ks {
153                    material.specularity = ks[0];
154                }
155            }
156
157            scene
158                .geometries
159                .push((start_index_index..scene.indices.len(), material));
160        }
161    }
162    scene.instances.push((
163        start_vertex_index..scene.vertices.len(),
164        start_geometry_index..scene.geometries.len(),
165    ));
166
167    // dbg!(scene.vertices.len());
168    // dbg!(scene.indices.len());
169    // dbg!(&scene.geometries);
170    // dbg!(&scene.instances);
171}
172
173fn upload_scene_components(
174    device: &wgpu::Device,
175    queue: &wgpu::Queue,
176    scene: &RawSceneComponents,
177) -> SceneComponents {
178    let geometry_buffer_content = scene
179        .geometries
180        .iter()
181        .map(|(index_range, material)| GeometryEntry {
182            first_index: index_range.start as u32,
183            material: *material,
184            ..Default::default()
185        })
186        .collect::<Vec<_>>();
187
188    let instance_buffer_content = scene
189        .instances
190        .iter()
191        .map(|geometry| InstanceEntry {
192            first_vertex: geometry.0.start as u32,
193            first_geometry: geometry.1.start as u32,
194            last_geometry: geometry.1.end as u32,
195            _pad: 1,
196        })
197        .collect::<Vec<_>>();
198
199    let vertices = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
200        label: Some("Vertices"),
201        contents: bytemuck::cast_slice(&scene.vertices),
202        usage: wgpu::BufferUsages::VERTEX
203            | wgpu::BufferUsages::STORAGE
204            | wgpu::BufferUsages::BLAS_INPUT,
205    });
206    let indices = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
207        label: Some("Indices"),
208        contents: bytemuck::cast_slice(&scene.indices),
209        usage: wgpu::BufferUsages::INDEX
210            | wgpu::BufferUsages::STORAGE
211            | wgpu::BufferUsages::BLAS_INPUT,
212    });
213    let geometries = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
214        label: Some("Geometries"),
215        contents: bytemuck::cast_slice(&geometry_buffer_content),
216        usage: wgpu::BufferUsages::STORAGE,
217    });
218    let instances = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
219        label: Some("Instances"),
220        contents: bytemuck::cast_slice(&instance_buffer_content),
221        usage: wgpu::BufferUsages::STORAGE,
222    });
223
224    let (size_descriptors, bottom_level_acceleration_structures): (Vec<_>, Vec<_>) = scene
225        .instances
226        .iter()
227        .map(|(vertex_range, geometry_range)| {
228            let size_desc: Vec<wgpu::BlasTriangleGeometrySizeDescriptor> = (*geometry_range)
229                .clone()
230                .map(|i| wgpu::BlasTriangleGeometrySizeDescriptor {
231                    vertex_format: wgpu::VertexFormat::Float32x3,
232                    vertex_count: vertex_range.end as u32 - vertex_range.start as u32,
233                    index_format: Some(wgpu::IndexFormat::Uint32),
234                    index_count: Some(
235                        scene.geometries[i].0.end as u32 - scene.geometries[i].0.start as u32,
236                    ),
237                    flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
238                })
239                .collect();
240
241            let blas = device.create_blas(
242                &wgpu::CreateBlasDescriptor {
243                    label: None,
244                    flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
245                    update_mode: wgpu::AccelerationStructureUpdateMode::Build,
246                },
247                wgpu::BlasGeometrySizeDescriptors::Triangles {
248                    descriptors: size_desc.clone(),
249                },
250            );
251            (size_desc, blas)
252        })
253        .unzip();
254
255    let build_entries: Vec<_> = scene
256        .instances
257        .iter()
258        .zip(size_descriptors.iter())
259        .zip(bottom_level_acceleration_structures.iter())
260        .map(|(((vertex_range, geometry_range), size_desc), blas)| {
261            let triangle_geometries: Vec<_> = size_desc
262                .iter()
263                .zip(geometry_range.clone())
264                .map(|(size, i)| wgpu::BlasTriangleGeometry {
265                    size,
266                    vertex_buffer: &vertices,
267                    first_vertex: vertex_range.start as u32,
268                    vertex_stride: mem::size_of::<Vertex>() as u64,
269                    index_buffer: Some(&indices),
270                    first_index: Some(scene.geometries[i].0.start as u32),
271                    transform_buffer: None,
272                    transform_buffer_offset: None,
273                })
274                .collect();
275
276            wgpu::BlasBuildEntry {
277                blas,
278                geometry: wgpu::BlasGeometries::TriangleGeometries(triangle_geometries),
279            }
280        })
281        .collect();
282
283    let mut encoder =
284        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
285
286    encoder.build_acceleration_structures(build_entries.iter(), iter::empty());
287
288    queue.submit(Some(encoder.finish()));
289
290    SceneComponents {
291        vertices,
292        indices,
293        geometries,
294        instances,
295        bottom_level_acceleration_structures,
296    }
297}
298
299fn load_scene(device: &wgpu::Device, queue: &wgpu::Queue) -> SceneComponents {
300    let mut scene = RawSceneComponents::default();
301
302    load_model(&mut scene, "/skybox/models/rustacean-3d.obj");
303    load_model(&mut scene, "/ray_scene/cube.obj");
304
305    upload_scene_components(device, queue, &scene)
306}
307
/// State of the `ray_scene` example that lives across frames.
struct Example {
    /// CPU copy of the camera uniforms; rewritten to `uniform_buf` on resize.
    uniforms: Uniforms,
    /// GPU buffer holding `uniforms`.
    uniform_buf: wgpu::Buffer,
    /// Top-level acceleration structure; its instances are re-posed and it is
    /// rebuilt every frame.
    tlas: wgpu::Tlas,
    /// Render pipeline built from `shader.wgsl` (`vs_main` / `fs_main`).
    pipeline: wgpu::RenderPipeline,
    /// Bind group 0: uniforms, scene buffers and the TLAS.
    bind_group: wgpu::BindGroup,
    /// Scene buffers and the bottom-level acceleration structures.
    scene_components: SceneComponents,
    /// Source of the animation time used to rotate the instances.
    animation_timer: utils::AnimationTimer,
}
317
318impl crate::framework::Example for Example {
    /// Device features the example needs: the experimental ray-query feature.
    fn required_features() -> wgpu::Features {
        wgpu::Features::EXPERIMENTAL_RAY_QUERY
    }
322
    /// Downlevel capabilities the example needs: the defaults, with `flags`
    /// set to `COMPUTE_SHADERS`.
    fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
        wgpu::DownlevelCapabilities {
            flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
            ..Default::default()
        }
    }
329
    /// Device limits the example needs: the default limits adjusted by
    /// `using_minimum_supported_acceleration_structure_values`.
    fn required_limits() -> wgpu::Limits {
        wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
    }
333
334    fn init(
335        config: &wgpu::SurfaceConfiguration,
336        _adapter: &wgpu::Adapter,
337        device: &wgpu::Device,
338        queue: &wgpu::Queue,
339    ) -> Self {
340        let side_count = 8;
341
342        let scene_components = load_scene(device, queue);
343
344        let uniforms = {
345            let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y);
346            let proj = Mat4::perspective_rh(
347                59.0_f32.to_radians(),
348                config.width as f32 / config.height as f32,
349                0.001,
350                1000.0,
351            );
352
353            Uniforms {
354                view_inverse: view.inverse(),
355                proj_inverse: proj.inverse(),
356            }
357        };
358
359        let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
360            label: Some("Uniform Buffer"),
361            contents: bytemuck::cast_slice(&[uniforms]),
362            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
363        });
364
365        let tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
366            label: None,
367            flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
368            update_mode: wgpu::AccelerationStructureUpdateMode::Build,
369            max_instances: side_count * side_count,
370        });
371
372        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
373            label: None,
374            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
375        });
376
377        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
378            label: None,
379            layout: None,
380            vertex: wgpu::VertexState {
381                module: &shader,
382                entry_point: Some("vs_main"),
383                compilation_options: Default::default(),
384                buffers: &[],
385            },
386            fragment: Some(wgpu::FragmentState {
387                module: &shader,
388                entry_point: Some("fs_main"),
389                compilation_options: Default::default(),
390                targets: &[Some(config.format.into())],
391            }),
392            primitive: wgpu::PrimitiveState {
393                topology: wgpu::PrimitiveTopology::TriangleList,
394                ..Default::default()
395            },
396            depth_stencil: None,
397            multisample: wgpu::MultisampleState::default(),
398            multiview: None,
399            cache: None,
400        });
401
402        let bind_group_layout = pipeline.get_bind_group_layout(0);
403
404        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
405            label: None,
406            layout: &bind_group_layout,
407            entries: &[
408                wgpu::BindGroupEntry {
409                    binding: 0,
410                    resource: uniform_buf.as_entire_binding(),
411                },
412                wgpu::BindGroupEntry {
413                    binding: 5,
414                    resource: tlas.as_binding(),
415                },
416                wgpu::BindGroupEntry {
417                    binding: 1,
418                    resource: scene_components.vertices.as_entire_binding(),
419                },
420                wgpu::BindGroupEntry {
421                    binding: 2,
422                    resource: scene_components.indices.as_entire_binding(),
423                },
424                wgpu::BindGroupEntry {
425                    binding: 3,
426                    resource: scene_components.geometries.as_entire_binding(),
427                },
428                wgpu::BindGroupEntry {
429                    binding: 4,
430                    resource: scene_components.instances.as_entire_binding(),
431                },
432            ],
433        });
434
435        Example {
436            uniforms,
437            uniform_buf,
438            tlas,
439            pipeline,
440            bind_group,
441            scene_components,
442            animation_timer: utils::AnimationTimer::default(),
443        }
444    }
445
    /// Window events are ignored; the example takes no interactive input.
    fn update(&mut self, _event: winit::event::WindowEvent) {}
447
448    fn resize(
449        &mut self,
450        config: &wgpu::SurfaceConfiguration,
451        _device: &wgpu::Device,
452        queue: &wgpu::Queue,
453    ) {
454        let proj = Mat4::perspective_rh(
455            59.0_f32.to_radians(),
456            config.width as f32 / config.height as f32,
457            0.001,
458            1000.0,
459        );
460
461        self.uniforms.proj_inverse = proj.inverse();
462
463        queue.write_buffer(&self.uniform_buf, 0, bytemuck::cast_slice(&[self.uniforms]));
464    }
465
466    fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
467        device.push_error_scope(wgpu::ErrorFilter::Validation);
468
469        // scene update
470        {
471            let dist = 3.5;
472
473            let side_count = 2;
474
475            let anim_time = self.animation_timer.time();
476
477            for x in 0..side_count {
478                for y in 0..side_count {
479                    let instance = self.tlas.index_mut(x + y * side_count);
480
481                    let blas_index = (x + y)
482                        % self
483                            .scene_components
484                            .bottom_level_acceleration_structures
485                            .len();
486
487                    let x = x as f32 / (side_count - 1) as f32;
488                    let y = y as f32 / (side_count - 1) as f32;
489                    let x = x * 2.0 - 1.0;
490                    let y = y * 2.0 - 1.0;
491
492                    let transform = Mat4::from_rotation_translation(
493                        Quat::from_euler(
494                            glam::EulerRot::XYZ,
495                            anim_time * 0.5 * 0.342,
496                            anim_time * 0.5 * 0.254,
497                            anim_time * 0.5 * 0.832 + PI,
498                        ),
499                        Vec3 {
500                            x: x * dist,
501                            y: y * dist,
502                            z: -14.0,
503                        },
504                    );
505                    let transform = transform.transpose().to_cols_array()[..12]
506                        .try_into()
507                        .unwrap();
508                    *instance = Some(wgpu::TlasInstance::new(
509                        &self.scene_components.bottom_level_acceleration_structures[blas_index],
510                        transform,
511                        blas_index as u32,
512                        0xff,
513                    ));
514                }
515            }
516        }
517
518        let mut encoder =
519            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
520
521        encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas));
522
523        {
524            let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
525                label: None,
526                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
527                    view,
528                    depth_slice: None,
529                    resolve_target: None,
530                    ops: wgpu::Operations {
531                        load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
532                        store: wgpu::StoreOp::Store,
533                    },
534                })],
535                depth_stencil_attachment: None,
536                timestamp_writes: None,
537                occlusion_query_set: None,
538            });
539
540            rpass.set_pipeline(&self.pipeline);
541            rpass.set_bind_group(0, Some(&self.bind_group), &[]);
542            rpass.draw(0..3, 0..1);
543        }
544
545        queue.submit(Some(encoder.finish()));
546    }
547}
548
/// Entry point: runs the example through the shared framework under the
/// title `"ray_scene"`.
pub fn main() {
    crate::framework::run::<Example>("ray_scene");
}
552
/// GPU image-comparison test for the example.
///
/// Renders at 1024x768 and compares the result against `screenshot.png`.
/// On the Vulkan backend with an adapter matching "llvmpipe" the test is
/// expected to fail with an "Image data mismatch" panic.
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_scene",
    image_path: "/examples/features/src/ray_scene/screenshot.png",
    width: 1024,
    height: 768,
    optional_features: wgpu::Features::default(),
    base_test_parameters: wgpu_test::TestParameters::default().expect_fail(
        wgpu_test::FailureCase::backend_adapter(wgpu::Backends::VULKAN, "llvmpipe")
            .panic("Image data mismatch"),
    ),
    // NOTE(review): presumably a threshold on the mean image difference — confirm in wgpu_test.
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    _phantom: std::marker::PhantomData::<Example>,
};