naga/back/spv/ray/pipeline.rs

1//! Code for ray tracing pipelines
2
3use crate::back::spv::{
4    Block, BlockContext, Instruction, LocalType, LookupRaytracingFunction, Writer, WriterFlags,
5};
6
7impl Writer {
8    fn write_trace_ray(
9        &mut self,
10        ir_module: &crate::Module,
11        payload: crate::Handle<crate::GlobalVariable>,
12    ) -> spirv::Word {
13        if let Some(&word) = self
14            .ray_tracing_functions
15            .get(&LookupRaytracingFunction::TraceRay { payload })
16        {
17            return word;
18        }
19
20        let acceleration_structure_type_id =
21            self.get_localtype_id(LocalType::AccelerationStructure);
22
23        let ray_desc_type_id = self.get_handle_type_id(
24            ir_module
25                .special_types
26                .ray_desc
27                .expect("ray desc should be set if `traceRays` is called"),
28        );
29
30        let (func_id, mut function, arg_ids) = self.write_function_signature(
31            &[acceleration_structure_type_id, ray_desc_type_id],
32            self.void_type,
33        );
34
35        let acceleration_structure_id = arg_ids[0];
36        let desc_id = arg_ids[1];
37        let payload_id = self.global_variables[payload].access_id;
38
39        let label_id = self.id_gen.next();
40        let mut block = Block::new(label_id);
41
42        let super::ExtractedRayDesc {
43            ray_flags_id,
44            cull_mask_id,
45            tmin_id,
46            tmax_id,
47            ray_origin_id,
48            ray_dir_id,
49            valid_id,
50        } = self.write_extract_ray_desc(&mut block, desc_id, self.trace_ray_argument_validation);
51
52        let merge_label_id = self.id_gen.next();
53        let merge_block = Block::new(merge_label_id);
54
55        // NOTE: this block will be unreachable if trace ray validation is disabled.
56        let invalid_label_id = self.id_gen.next();
57        let mut invalid_block = Block::new(invalid_label_id);
58
59        let valid_label_id = self.id_gen.next();
60        let mut valid_block = Block::new(valid_label_id);
61
62        match valid_id {
63            Some(all_valid_id) => {
64                block.body.push(Instruction::selection_merge(
65                    merge_label_id,
66                    spirv::SelectionControl::NONE,
67                ));
68                function.consume(
69                    block,
70                    Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
71                );
72            }
73            None => {
74                function.consume(block, Instruction::branch(valid_label_id));
75            }
76        }
77
78        let zero = self.get_constant_scalar(crate::Literal::U32(0));
79
80        valid_block.body.push(Instruction::trace_ray(
81            acceleration_structure_id,
82            ray_flags_id,
83            cull_mask_id,
84            zero,
85            zero,
86            zero,
87            ray_origin_id,
88            tmin_id,
89            ray_dir_id,
90            tmax_id,
91            payload_id,
92        ));
93
94        function.consume(valid_block, Instruction::branch(merge_label_id));
95
96        if self.flags.contains(WriterFlags::PRINT_ON_TRACE_RAYS_FAIL) {
97            self.write_debug_printf(
98                &mut invalid_block,
99                "Naga ignored invalid arguments to traceRay with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
100                &[
101                    ray_flags_id,
102                    tmin_id,
103                    tmax_id,
104                    ray_origin_id,
105                    ray_dir_id,
106                ],
107            );
108        }
109
110        function.consume(invalid_block, Instruction::branch(merge_label_id));
111
112        function.consume(merge_block, Instruction::return_void());
113
114        function.to_words(&mut self.logical_layout.function_definitions);
115
116        self.ray_tracing_functions
117            .insert(LookupRaytracingFunction::TraceRay { payload }, func_id);
118
119        func_id
120    }
121}
122
123impl BlockContext<'_> {
124    pub(in super::super) fn write_ray_tracing_pipeline_function(
125        &mut self,
126        function: &crate::RayPipelineFunction,
127        block: &mut Block,
128    ) {
129        match *function {
130            crate::RayPipelineFunction::TraceRay {
131                acceleration_structure,
132                descriptor,
133                payload,
134            } => {
135                // Checked for when validating the module in `validate_block_impl`.
136                let crate::Expression::GlobalVariable(payload) =
137                    self.ir_function.expressions[payload]
138                else {
139                    unreachable!()
140                };
141
142                let desc_id = self.cached[descriptor];
143                let acc_struct_id = self.get_handle_id(acceleration_structure);
144
145                let func = self.writer.write_trace_ray(self.ir_module, payload);
146
147                let func_id = self.gen_id();
148                block.body.push(Instruction::function_call(
149                    self.writer.void_type,
150                    func_id,
151                    func,
152                    &[acc_struct_id, desc_id],
153                ));
154            }
155        }
156    }
157}