1use std::{borrow::Cow, iter, mem};
2
3use bytemuck::{Pod, Zeroable};
4use glam::{Affine3A, Mat4, Quat, Vec3};
5use wgpu::util::DeviceExt;
6
7use wgpu::StoreOp;
8
9use crate::utils;
10
11#[repr(C)]
13#[derive(Clone, Copy, Pod, Zeroable)]
14struct Vertex {
15 _pos: [f32; 4],
16 _tex_coord: [f32; 2],
17}
18
19fn vertex(pos: [i8; 3], tc: [i8; 2]) -> Vertex {
20 Vertex {
21 _pos: [pos[0] as f32, pos[1] as f32, pos[2] as f32, 1.0],
22 _tex_coord: [tc[0] as f32, tc[1] as f32],
23 }
24}
25
26fn create_vertices() -> (Vec<Vertex>, Vec<u16>) {
27 let vertex_data = [
28 vertex([-1, -1, 1], [0, 0]),
30 vertex([1, -1, 1], [1, 0]),
31 vertex([1, 1, 1], [1, 1]),
32 vertex([-1, 1, 1], [0, 1]),
33 vertex([-1, 1, -1], [1, 0]),
35 vertex([1, 1, -1], [0, 0]),
36 vertex([1, -1, -1], [0, 1]),
37 vertex([-1, -1, -1], [1, 1]),
38 vertex([1, -1, -1], [0, 0]),
40 vertex([1, 1, -1], [1, 0]),
41 vertex([1, 1, 1], [1, 1]),
42 vertex([1, -1, 1], [0, 1]),
43 vertex([-1, -1, 1], [1, 0]),
45 vertex([-1, 1, 1], [0, 0]),
46 vertex([-1, 1, -1], [0, 1]),
47 vertex([-1, -1, -1], [1, 1]),
48 vertex([1, 1, -1], [1, 0]),
50 vertex([-1, 1, -1], [0, 0]),
51 vertex([-1, 1, 1], [0, 1]),
52 vertex([1, 1, 1], [1, 1]),
53 vertex([1, -1, 1], [0, 0]),
55 vertex([-1, -1, 1], [1, 0]),
56 vertex([-1, -1, -1], [1, 1]),
57 vertex([1, -1, -1], [0, 1]),
58 ];
59
60 let index_data: &[u16] = &[
61 0, 1, 2, 2, 3, 0, 4, 5, 6, 6, 7, 4, 8, 9, 10, 10, 11, 8, 12, 13, 14, 14, 15, 12, 16, 17, 18, 18, 19, 16, 20, 21, 22, 22, 23, 20, ];
68
69 (vertex_data.to_vec(), index_data.to_vec())
70}
71
72#[repr(C)]
73#[derive(Clone, Copy, Pod, Zeroable)]
74struct Uniforms {
75 view_inverse: Mat4,
76 proj_inverse: Mat4,
77}
78
79#[inline]
80fn affine_to_rows(mat: &Affine3A) -> [f32; 12] {
81 let row_0 = mat.matrix3.row(0);
82 let row_1 = mat.matrix3.row(1);
83 let row_2 = mat.matrix3.row(2);
84 let translation = mat.translation;
85 [
86 row_0.x,
87 row_0.y,
88 row_0.z,
89 translation.x,
90 row_1.x,
91 row_1.y,
92 row_1.z,
93 translation.y,
94 row_2.x,
95 row_2.y,
96 row_2.z,
97 translation.z,
98 ]
99}
100
/// State for the compute-shader ray tracing example.
///
/// Holds the acceleration structures, the compute and blit pipelines with their
/// bind groups, and the texture the compute pass renders into.
struct Example {
    // Ray tracing output texture (STORAGE_BINDING | TEXTURE_BINDING). Its view
    // is bound at binding 0 of both the compute and blit bind groups; its size
    // sets the compute dispatch size.
    rt_target: wgpu::Texture,
    // The fields marked `#[expect(dead_code)]` are owned here but never read
    // after construction.
    #[expect(dead_code)]
    rt_view: wgpu::TextureView,
    #[expect(dead_code)]
    sampler: wgpu::Sampler,
    #[expect(dead_code)]
    uniform_buf: wgpu::Buffer,
    #[expect(dead_code)]
    vertex_buf: wgpu::Buffer,
    #[expect(dead_code)]
    index_buf: wgpu::Buffer,
    // Top-level acceleration structure with one instance per cube. It is
    // rebuilt every frame after instance 0's transform is updated.
    tlas: wgpu::Tlas,
    compute_pipeline: wgpu::ComputePipeline,
    compute_bind_group: wgpu::BindGroup,
    blit_pipeline: wgpu::RenderPipeline,
    blit_bind_group: wgpu::BindGroup,
    // Supplies the time value that drives the animated instance's rotation.
    animation_timer: utils::AnimationTimer,
}
120
121impl crate::framework::Example for Example {
122 fn required_features() -> wgpu::Features {
123 wgpu::Features::TEXTURE_BINDING_ARRAY
124 | wgpu::Features::VERTEX_WRITABLE_STORAGE
125 | wgpu::Features::EXPERIMENTAL_RAY_QUERY
126 }
127
128 fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities {
129 wgpu::DownlevelCapabilities {
130 flags: wgpu::DownlevelFlags::COMPUTE_SHADERS,
131 ..Default::default()
132 }
133 }
134
135 fn required_limits() -> wgpu::Limits {
136 wgpu::Limits::default().using_minimum_supported_acceleration_structure_values()
137 }
138
139 fn init(
140 config: &wgpu::SurfaceConfiguration,
141 _adapter: &wgpu::Adapter,
142 device: &wgpu::Device,
143 queue: &wgpu::Queue,
144 ) -> Self {
145 let side_count = 8;
146
147 let rt_target = device.create_texture(&wgpu::TextureDescriptor {
148 label: Some("rt_target"),
149 size: wgpu::Extent3d {
150 width: config.width,
151 height: config.height,
152 depth_or_array_layers: 1,
153 },
154 mip_level_count: 1,
155 sample_count: 1,
156 dimension: wgpu::TextureDimension::D2,
157 format: wgpu::TextureFormat::Rgba8Unorm,
158 usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
159 view_formats: &[wgpu::TextureFormat::Rgba8Unorm],
160 });
161
162 let rt_view = rt_target.create_view(&wgpu::TextureViewDescriptor {
163 label: None,
164 format: Some(wgpu::TextureFormat::Rgba8Unorm),
165 dimension: Some(wgpu::TextureViewDimension::D2),
166 usage: None,
167 aspect: wgpu::TextureAspect::All,
168 base_mip_level: 0,
169 mip_level_count: None,
170 base_array_layer: 0,
171 array_layer_count: None,
172 });
173
174 let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
175 label: Some("rt_sampler"),
176 address_mode_u: wgpu::AddressMode::ClampToEdge,
177 address_mode_v: wgpu::AddressMode::ClampToEdge,
178 address_mode_w: wgpu::AddressMode::ClampToEdge,
179 mag_filter: wgpu::FilterMode::Linear,
180 min_filter: wgpu::FilterMode::Linear,
181 mipmap_filter: wgpu::MipmapFilterMode::Nearest,
182 ..Default::default()
183 });
184
185 let uniforms = {
186 let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y);
187 let proj = Mat4::perspective_rh(
188 59.0_f32.to_radians(),
189 config.width as f32 / config.height as f32,
190 0.001,
191 1000.0,
192 );
193
194 Uniforms {
195 view_inverse: view.inverse(),
196 proj_inverse: proj.inverse(),
197 }
198 };
199
200 let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
201 label: Some("Uniform Buffer"),
202 contents: bytemuck::cast_slice(&[uniforms]),
203 usage: wgpu::BufferUsages::UNIFORM,
204 });
205
206 let (vertex_data, index_data) = create_vertices();
207
208 let vertex_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
209 label: Some("Vertex Buffer"),
210 contents: bytemuck::cast_slice(&vertex_data),
211 usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::BLAS_INPUT,
212 });
213
214 let index_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
215 label: Some("Index Buffer"),
216 contents: bytemuck::cast_slice(&index_data),
217 usage: wgpu::BufferUsages::INDEX | wgpu::BufferUsages::BLAS_INPUT,
218 });
219
220 let blas_geo_size_desc = wgpu::BlasTriangleGeometrySizeDescriptor {
221 vertex_format: wgpu::VertexFormat::Float32x3,
222 vertex_count: vertex_data.len() as u32,
223 index_format: Some(wgpu::IndexFormat::Uint16),
224 index_count: Some(index_data.len() as u32),
225 flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
226 };
227
228 let blas = device.create_blas(
229 &wgpu::CreateBlasDescriptor {
230 label: None,
231 flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
232 update_mode: wgpu::AccelerationStructureUpdateMode::Build,
233 },
234 wgpu::BlasGeometrySizeDescriptors::Triangles {
235 descriptors: vec![blas_geo_size_desc.clone()],
236 },
237 );
238
239 let mut tlas = device.create_tlas(&wgpu::CreateTlasDescriptor {
240 label: None,
241 flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
242 update_mode: wgpu::AccelerationStructureUpdateMode::Build,
243 max_instances: side_count * side_count,
244 });
245
246 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
247 label: Some("rt_computer"),
248 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
249 });
250
251 let blit_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
252 label: Some("blit"),
253 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("blit.wgsl"))),
254 });
255
256 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
257 label: Some("rt"),
258 layout: None,
259 module: &shader,
260 entry_point: Some("main"),
261 compilation_options: Default::default(),
262 cache: None,
263 });
264
265 let compute_bind_group_layout = compute_pipeline.get_bind_group_layout(0);
266
267 let compute_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
268 label: None,
269 layout: &compute_bind_group_layout,
270 entries: &[
271 wgpu::BindGroupEntry {
272 binding: 0,
273 resource: wgpu::BindingResource::TextureView(&rt_view),
274 },
275 wgpu::BindGroupEntry {
276 binding: 1,
277 resource: uniform_buf.as_entire_binding(),
278 },
279 wgpu::BindGroupEntry {
280 binding: 2,
281 resource: wgpu::BindingResource::AccelerationStructure(&tlas),
282 },
283 ],
284 });
285
286 let blit_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
287 label: Some("blit"),
288 layout: None,
289 vertex: wgpu::VertexState {
290 module: &blit_shader,
291 entry_point: Some("vs_main"),
292 compilation_options: Default::default(),
293 buffers: &[],
294 },
295 fragment: Some(wgpu::FragmentState {
296 module: &blit_shader,
297 entry_point: Some("fs_main"),
298 compilation_options: Default::default(),
299 targets: &[Some(config.format.into())],
300 }),
301 primitive: wgpu::PrimitiveState {
302 topology: wgpu::PrimitiveTopology::TriangleList,
303 ..Default::default()
304 },
305 depth_stencil: None,
306 multisample: wgpu::MultisampleState::default(),
307 multiview_mask: None,
308 cache: None,
309 });
310
311 let blit_bind_group_layout = blit_pipeline.get_bind_group_layout(0);
312
313 let blit_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
314 label: None,
315 layout: &blit_bind_group_layout,
316 entries: &[
317 wgpu::BindGroupEntry {
318 binding: 0,
319 resource: wgpu::BindingResource::TextureView(&rt_view),
320 },
321 wgpu::BindGroupEntry {
322 binding: 1,
323 resource: wgpu::BindingResource::Sampler(&sampler),
324 },
325 ],
326 });
327
328 let dist = 3.0;
329
330 for x in 0..side_count {
331 for y in 0..side_count {
332 tlas[(x + y * side_count) as usize] = Some(wgpu::TlasInstance::new(
333 &blas,
334 affine_to_rows(&Affine3A::from_rotation_translation(
335 Quat::from_rotation_y(45.9_f32.to_radians()),
336 Vec3 {
337 x: x as f32 * dist,
338 y: y as f32 * dist,
339 z: -30.0,
340 },
341 )),
342 0,
343 0xff,
344 ));
345 }
346 }
347
348 let mut encoder =
349 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
350
351 encoder.build_acceleration_structures(
352 iter::once(&wgpu::BlasBuildEntry {
353 blas: &blas,
354 geometry: wgpu::BlasGeometries::TriangleGeometries(vec![
355 wgpu::BlasTriangleGeometry {
356 size: &blas_geo_size_desc,
357 vertex_buffer: &vertex_buf,
358 first_vertex: 0,
359 vertex_stride: mem::size_of::<Vertex>() as u64,
360 index_buffer: Some(&index_buf),
361 first_index: Some(0),
362 transform_buffer: None,
363 transform_buffer_offset: None,
364 },
365 ]),
366 }),
367 iter::once(&tlas),
368 );
369
370 queue.submit(Some(encoder.finish()));
371
372 Example {
373 rt_target,
374 rt_view,
375 sampler,
376 uniform_buf,
377 vertex_buf,
378 index_buf,
379 tlas,
380 compute_pipeline,
381 compute_bind_group,
382 blit_pipeline,
383 blit_bind_group,
384 animation_timer: utils::AnimationTimer::default(),
385 }
386 }
387
388 fn update(&mut self, _event: winit::event::WindowEvent) {
389 }
391
392 fn resize(
393 &mut self,
394 _config: &wgpu::SurfaceConfiguration,
395 _device: &wgpu::Device,
396 _queue: &wgpu::Queue,
397 ) {
398 }
399
400 fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) {
401 let anim_time = self.animation_timer.time();
402
403 self.tlas[0].as_mut().unwrap().transform =
404 affine_to_rows(&Affine3A::from_rotation_translation(
405 Quat::from_euler(
406 glam::EulerRot::XYZ,
407 anim_time * 0.342,
408 anim_time * 0.254,
409 anim_time * 0.832,
410 ),
411 Vec3 {
412 x: 0.0,
413 y: 0.0,
414 z: -6.0,
415 },
416 ));
417
418 let mut encoder =
419 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
420
421 encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas));
422
423 {
424 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
425 label: None,
426 timestamp_writes: None,
427 });
428 cpass.set_pipeline(&self.compute_pipeline);
429 cpass.set_bind_group(0, Some(&self.compute_bind_group), &[]);
430 cpass.dispatch_workgroups(self.rt_target.width() / 8, self.rt_target.height() / 8, 1);
431 }
432
433 {
434 let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
435 label: None,
436 color_attachments: &[Some(wgpu::RenderPassColorAttachment {
437 view,
438 depth_slice: None,
439 resolve_target: None,
440 ops: wgpu::Operations {
441 load: wgpu::LoadOp::Clear(wgpu::Color::GREEN),
442 store: StoreOp::Store,
443 },
444 })],
445 depth_stencil_attachment: None,
446 timestamp_writes: None,
447 occlusion_query_set: None,
448 multiview_mask: None,
449 });
450
451 rpass.set_pipeline(&self.blit_pipeline);
452 rpass.set_bind_group(0, Some(&self.blit_bind_group), &[]);
453 rpass.draw(0..3, 0..1);
454 }
455
456 queue.submit(Some(encoder.finish()));
457 }
458}
459
460pub fn main() {
461 crate::framework::run::<Example>("ray-cube");
462}
463
// GPU screenshot test for this example.
// It renders at 1024x768 and compares against the stored reference image
// using a `Mean(0.02)` comparison (presumably a mean per-pixel difference
// threshold — confirm against `wgpu_test::ComparisonType`).
#[cfg(test)]
#[wgpu_test::gpu_test]
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
    name: "ray_cube_compute",
    image_path: "/examples/features/src/ray_cube_compute/screenshot.png",
    width: 1024,
    height: 768,
    optional_features: wgpu::Features::default(),
    // Expected to fail on the Metal backend; the reason is not recorded here.
    base_test_parameters: wgpu_test::TestParameters::default()
        .expect_fail(wgpu_test::FailureCase::backend(wgpu::Backends::METAL)),
    comparisons: &[wgpu_test::ComparisonType::Mean(0.02)],
    _phantom: std::marker::PhantomData::<Example>,
};