naga/back/spv/ray/
pipeline.rs1use crate::back::spv::{
4 Block, BlockContext, Instruction, LocalType, LookupRaytracingFunction, Writer, WriterFlags,
5};
6
7impl Writer {
8 fn write_trace_ray(
9 &mut self,
10 ir_module: &crate::Module,
11 payload: crate::Handle<crate::GlobalVariable>,
12 ) -> spirv::Word {
13 if let Some(&word) = self
14 .ray_tracing_functions
15 .get(&LookupRaytracingFunction::TraceRay { payload })
16 {
17 return word;
18 }
19
20 let acceleration_structure_type_id =
21 self.get_localtype_id(LocalType::AccelerationStructure);
22
23 let ray_desc_type_id = self.get_handle_type_id(
24 ir_module
25 .special_types
26 .ray_desc
27 .expect("ray desc should be set if `traceRays` is called"),
28 );
29
30 let (func_id, mut function, arg_ids) = self.write_function_signature(
31 &[acceleration_structure_type_id, ray_desc_type_id],
32 self.void_type,
33 );
34
35 let acceleration_structure_id = arg_ids[0];
36 let desc_id = arg_ids[1];
37 let payload_id = self.global_variables[payload].access_id;
38
39 let label_id = self.id_gen.next();
40 let mut block = Block::new(label_id);
41
42 let super::ExtractedRayDesc {
43 ray_flags_id,
44 cull_mask_id,
45 tmin_id,
46 tmax_id,
47 ray_origin_id,
48 ray_dir_id,
49 valid_id,
50 } = self.write_extract_ray_desc(&mut block, desc_id, self.trace_ray_argument_validation);
51
52 let merge_label_id = self.id_gen.next();
53 let merge_block = Block::new(merge_label_id);
54
55 let invalid_label_id = self.id_gen.next();
57 let mut invalid_block = Block::new(invalid_label_id);
58
59 let valid_label_id = self.id_gen.next();
60 let mut valid_block = Block::new(valid_label_id);
61
62 match valid_id {
63 Some(all_valid_id) => {
64 block.body.push(Instruction::selection_merge(
65 merge_label_id,
66 spirv::SelectionControl::NONE,
67 ));
68 function.consume(
69 block,
70 Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
71 );
72 }
73 None => {
74 function.consume(block, Instruction::branch(valid_label_id));
75 }
76 }
77
78 let zero = self.get_constant_scalar(crate::Literal::U32(0));
79
80 valid_block.body.push(Instruction::trace_ray(
81 acceleration_structure_id,
82 ray_flags_id,
83 cull_mask_id,
84 zero,
85 zero,
86 zero,
87 ray_origin_id,
88 tmin_id,
89 ray_dir_id,
90 tmax_id,
91 payload_id,
92 ));
93
94 function.consume(valid_block, Instruction::branch(merge_label_id));
95
96 if self.flags.contains(WriterFlags::PRINT_ON_TRACE_RAYS_FAIL) {
97 self.write_debug_printf(
98 &mut invalid_block,
99 "Naga ignored invalid arguments to traceRay with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
100 &[
101 ray_flags_id,
102 tmin_id,
103 tmax_id,
104 ray_origin_id,
105 ray_dir_id,
106 ],
107 );
108 }
109
110 function.consume(invalid_block, Instruction::branch(merge_label_id));
111
112 function.consume(merge_block, Instruction::return_void());
113
114 function.to_words(&mut self.logical_layout.function_definitions);
115
116 self.ray_tracing_functions
117 .insert(LookupRaytracingFunction::TraceRay { payload }, func_id);
118
119 func_id
120 }
121}
122
123impl BlockContext<'_> {
124 pub(in super::super) fn write_ray_tracing_pipeline_function(
125 &mut self,
126 function: &crate::RayPipelineFunction,
127 block: &mut Block,
128 ) {
129 match *function {
130 crate::RayPipelineFunction::TraceRay {
131 acceleration_structure,
132 descriptor,
133 payload,
134 } => {
135 let crate::Expression::GlobalVariable(payload) =
137 self.ir_function.expressions[payload]
138 else {
139 unreachable!()
140 };
141
142 let desc_id = self.cached[descriptor];
143 let acc_struct_id = self.get_handle_id(acceleration_structure);
144
145 let func = self.writer.write_trace_ray(self.ir_module, payload);
146
147 let func_id = self.gen_id();
148 block.body.push(Instruction::function_call(
149 self.writer.void_type,
150 func_id,
151 func,
152 &[acc_struct_id, desc_id],
153 ));
154 }
155 }
156 }
157}