naga/back/spv/ray/
mod.rs

1/*!
2Module for code shared between ray queries and ray tracing pipeline code.
3*/
4
5pub mod pipeline;
6pub mod query;
7
8use alloc::{vec, vec::Vec};
9
10use super::{Block, Function, FunctionArgument, Instruction, LookupFunctionType, Writer};
11
12struct ExtractedRayDesc {
13    ray_flags_id: spirv::Word,
14    cull_mask_id: spirv::Word,
15    tmin_id: spirv::Word,
16    tmax_id: spirv::Word,
17    ray_origin_id: spirv::Word,
18    ray_dir_id: spirv::Word,
19    valid_id: Option<spirv::Word>,
20}
21
22/// helper function to check if a particular flag is set in a u32.
23fn write_ray_flags_contains_flags(
24    writer: &mut Writer,
25    block: &mut Block,
26    id: spirv::Word,
27    flag: u32,
28) -> spirv::Word {
29    let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag));
30    let zero_id = writer.get_constant_scalar(crate::Literal::U32(0));
31    let u32_type_id = writer.get_u32_type_id();
32    let bool_ty = writer.get_bool_type_id();
33
34    let and_id = writer.id_gen.next();
35    block.body.push(Instruction::binary(
36        spirv::Op::BitwiseAnd,
37        u32_type_id,
38        and_id,
39        id,
40        bit_id,
41    ));
42
43    let eq_id = writer.id_gen.next();
44    block.body.push(Instruction::binary(
45        spirv::Op::INotEqual,
46        bool_ty,
47        eq_id,
48        and_id,
49        zero_id,
50    ));
51
52    eq_id
53}
54
55impl Writer {
56    fn write_extract_ray_desc(
57        &mut self,
58        block: &mut Block,
59        desc_id: spirv::Word,
60        validate: bool,
61    ) -> ExtractedRayDesc {
62        let bool_type_id = self.get_bool_type_id();
63        let bool_vec3_type_id = self.get_vec3_bool_type_id();
64        let f32_type_id = self.get_f32_type_id();
65        let flag_type_id = self.get_numeric_type_id(super::NumericType::Scalar(crate::Scalar::U32));
66
67        //Note: composite extract indices and types must match `generate_ray_desc_type`
68        let ray_flags_id = self.id_gen.next();
69        block.body.push(Instruction::composite_extract(
70            flag_type_id,
71            ray_flags_id,
72            desc_id,
73            &[0],
74        ));
75        let cull_mask_id = self.id_gen.next();
76        block.body.push(Instruction::composite_extract(
77            flag_type_id,
78            cull_mask_id,
79            desc_id,
80            &[1],
81        ));
82
83        let tmin_id = self.id_gen.next();
84        block.body.push(Instruction::composite_extract(
85            f32_type_id,
86            tmin_id,
87            desc_id,
88            &[2],
89        ));
90        let tmax_id = self.id_gen.next();
91        block.body.push(Instruction::composite_extract(
92            f32_type_id,
93            tmax_id,
94            desc_id,
95            &[3],
96        ));
97
98        let vector_type_id = self.get_numeric_type_id(super::NumericType::Vector {
99            size: crate::VectorSize::Tri,
100            scalar: crate::Scalar::F32,
101        });
102        let ray_origin_id = self.id_gen.next();
103        block.body.push(Instruction::composite_extract(
104            vector_type_id,
105            ray_origin_id,
106            desc_id,
107            &[4],
108        ));
109        let ray_dir_id = self.id_gen.next();
110        block.body.push(Instruction::composite_extract(
111            vector_type_id,
112            ray_dir_id,
113            desc_id,
114            &[5],
115        ));
116
117        let valid_id = validate.then(||{
118            let tmin_le_tmax_id = self.id_gen.next();
119            // Check both that tmin is less than or equal to tmax (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06350)
120            // and implicitly that neither tmin or tmax are NaN (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06351)
121            // because this checks if tmin and tmax are ordered too (i.e: not NaN).
122            block.body.push(Instruction::binary(
123                spirv::Op::FOrdLessThanEqual,
124                bool_type_id,
125                tmin_le_tmax_id,
126                tmin_id,
127                tmax_id,
128            ));
129
130            // Check that tmin is greater than or equal to 0 (and
131            // therefore also tmax is too because it is greater than
132            // or equal to tmin) (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06349).
133            let tmin_ge_zero_id = self.id_gen.next();
134            let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0));
135            block.body.push(Instruction::binary(
136                spirv::Op::FOrdGreaterThanEqual,
137                bool_type_id,
138                tmin_ge_zero_id,
139                tmin_id,
140                zero_id,
141            ));
142
143            // Check that ray origin is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
144            let ray_origin_infinite_id = self.id_gen.next();
145            block.body.push(Instruction::unary(
146                spirv::Op::IsInf,
147                bool_vec3_type_id,
148                ray_origin_infinite_id,
149                ray_origin_id,
150            ));
151            let any_ray_origin_infinite_id = self.id_gen.next();
152            block.body.push(Instruction::unary(
153                spirv::Op::Any,
154                bool_type_id,
155                any_ray_origin_infinite_id,
156                ray_origin_infinite_id,
157            ));
158
159            let ray_origin_nan_id = self.id_gen.next();
160            block.body.push(Instruction::unary(
161                spirv::Op::IsNan,
162                bool_vec3_type_id,
163                ray_origin_nan_id,
164                ray_origin_id,
165            ));
166            let any_ray_origin_nan_id = self.id_gen.next();
167            block.body.push(Instruction::unary(
168                spirv::Op::Any,
169                bool_type_id,
170                any_ray_origin_nan_id,
171                ray_origin_nan_id,
172            ));
173
174            let ray_origin_not_finite_id = self.id_gen.next();
175            block.body.push(Instruction::binary(
176                spirv::Op::LogicalOr,
177                bool_type_id,
178                ray_origin_not_finite_id,
179                any_ray_origin_nan_id,
180                any_ray_origin_infinite_id,
181            ));
182
183            let all_ray_origin_finite_id = self.id_gen.next();
184            block.body.push(Instruction::unary(
185                spirv::Op::LogicalNot,
186                bool_type_id,
187                all_ray_origin_finite_id,
188                ray_origin_not_finite_id,
189            ));
190
191            // Check that ray direction is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
192            let ray_dir_infinite_id = self.id_gen.next();
193            block.body.push(Instruction::unary(
194                spirv::Op::IsInf,
195                bool_vec3_type_id,
196                ray_dir_infinite_id,
197                ray_dir_id,
198            ));
199            let any_ray_dir_infinite_id = self.id_gen.next();
200            block.body.push(Instruction::unary(
201                spirv::Op::Any,
202                bool_type_id,
203                any_ray_dir_infinite_id,
204                ray_dir_infinite_id,
205            ));
206
207            let ray_dir_nan_id = self.id_gen.next();
208            block.body.push(Instruction::unary(
209                spirv::Op::IsNan,
210                bool_vec3_type_id,
211                ray_dir_nan_id,
212                ray_dir_id,
213            ));
214            let any_ray_dir_nan_id = self.id_gen.next();
215            block.body.push(Instruction::unary(
216                spirv::Op::Any,
217                bool_type_id,
218                any_ray_dir_nan_id,
219                ray_dir_nan_id,
220            ));
221
222            let ray_dir_not_finite_id = self.id_gen.next();
223            block.body.push(Instruction::binary(
224                spirv::Op::LogicalOr,
225                bool_type_id,
226                ray_dir_not_finite_id,
227                any_ray_dir_nan_id,
228                any_ray_dir_infinite_id,
229            ));
230
231            let all_ray_dir_finite_id = self.id_gen.next();
232            block.body.push(Instruction::unary(
233                spirv::Op::LogicalNot,
234                bool_type_id,
235                all_ray_dir_finite_id,
236                ray_dir_not_finite_id,
237            ));
238
239            /// Writes spirv to check that less than two booleans are true
240            ///
241            /// For each boolean: removes it, `and`s it with all others (i.e for all possible combinations of two booleans in the list checks to see if both are true).
242            /// Then `or`s all of these checks together. This produces whether two or more booleans are true.
243            fn write_less_than_2_true(
244                writer: &mut Writer,
245                block: &mut Block,
246                mut bools: Vec<spirv::Word>,
247            ) -> spirv::Word {
248                assert!(bools.len() > 1, "Must have multiple booleans!");
249                let bool_ty = writer.get_bool_type_id();
250                let mut each_two_true = Vec::new();
251                while let Some(last_bool) = bools.pop() {
252                    for &bool in &bools {
253                        let both_true_id = writer.write_logical_and(
254                            block,
255                            last_bool,
256                            bool,
257                        );
258                        each_two_true.push(both_true_id);
259                    }
260                }
261                let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true`");
262                for two_true in each_two_true {
263                    let new_all_or_id = writer.id_gen.next();
264                    block.body.push(Instruction::binary(
265                        spirv::Op::LogicalOr,
266                        bool_ty,
267                        new_all_or_id,
268                        all_or_id,
269                        two_true,
270                    ));
271                    all_or_id = new_all_or_id;
272                }
273
274                let less_than_two_id = writer.id_gen.next();
275                block.body.push(Instruction::unary(
276                    spirv::Op::LogicalNot,
277                    bool_ty,
278                    less_than_two_id,
279                    all_or_id,
280                ));
281                less_than_two_id
282            }
283
284            // Check that at most one of skip triangles and skip AABBs is
285            // present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06889)
286            let contains_skip_triangles = write_ray_flags_contains_flags(
287                self,
288                block,
289                ray_flags_id,
290                crate::RayFlag::SKIP_TRIANGLES.bits(),
291            );
292            let contains_skip_aabbs = write_ray_flags_contains_flags(
293                self,
294                block,
295                ray_flags_id,
296                crate::RayFlag::SKIP_AABBS.bits(),
297            );
298
299            let not_contain_skip_triangles_aabbs = write_less_than_2_true(
300                self,
301                block,
302                vec![contains_skip_triangles, contains_skip_aabbs],
303            );
304
305            // Check that at most one of skip triangles (taken from above check),
306            // cull back facing, and cull front face is present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06890)
307            let contains_cull_back = write_ray_flags_contains_flags(
308                self,
309                block,
310                ray_flags_id,
311                crate::RayFlag::CULL_BACK_FACING.bits(),
312            );
313            let contains_cull_front = write_ray_flags_contains_flags(
314                self,
315                block,
316                ray_flags_id,
317                crate::RayFlag::CULL_FRONT_FACING.bits(),
318            );
319
320            let not_contain_skip_triangles_cull = write_less_than_2_true(
321                self,
322                block,
323                vec![
324                    contains_skip_triangles,
325                    contains_cull_back,
326                    contains_cull_front,
327                ],
328            );
329
330            // Check that at most one of force opaque, force not opaque, cull opaque,
331            // and cull not opaque are present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06891)
332            let contains_opaque = write_ray_flags_contains_flags(
333                self,
334                block,
335                ray_flags_id,
336                crate::RayFlag::FORCE_OPAQUE.bits(),
337            );
338            let contains_no_opaque = write_ray_flags_contains_flags(
339                self,
340                block,
341                ray_flags_id,
342                crate::RayFlag::FORCE_NO_OPAQUE.bits(),
343            );
344            let contains_cull_opaque = write_ray_flags_contains_flags(
345                self,
346                block,
347                ray_flags_id,
348                crate::RayFlag::CULL_OPAQUE.bits(),
349            );
350            let contains_cull_no_opaque = write_ray_flags_contains_flags(
351                self,
352                block,
353                ray_flags_id,
354                crate::RayFlag::CULL_NO_OPAQUE.bits(),
355            );
356
357            let not_contain_multiple_opaque = write_less_than_2_true(
358                self,
359                block,
360                vec![
361                    contains_opaque,
362                    contains_no_opaque,
363                    contains_cull_opaque,
364                    contains_cull_no_opaque,
365                ],
366            );
367
368            // Combine all checks into a single flag saying whether the call is valid or not.
369            self.write_reduce_and(
370                block,
371                vec![
372                    tmin_le_tmax_id,
373                    tmin_ge_zero_id,
374                    all_ray_origin_finite_id,
375                    all_ray_dir_finite_id,
376                    not_contain_skip_triangles_aabbs,
377                    not_contain_skip_triangles_cull,
378                    not_contain_multiple_opaque,
379                ],
380            )
381        });
382
383        ExtractedRayDesc {
384            ray_flags_id,
385            cull_mask_id,
386            tmin_id,
387            tmax_id,
388            ray_origin_id,
389            ray_dir_id,
390            valid_id,
391        }
392    }
393    /// writes a logical and of two scalar booleans
394    fn write_logical_and(
395        &mut self,
396        block: &mut Block,
397        one: spirv::Word,
398        two: spirv::Word,
399    ) -> spirv::Word {
400        let id = self.id_gen.next();
401        let bool_id = self.get_bool_type_id();
402        block.body.push(Instruction::binary(
403            spirv::Op::LogicalAnd,
404            bool_id,
405            id,
406            one,
407            two,
408        ));
409        id
410    }
411
412    fn write_reduce_and(&mut self, block: &mut Block, mut bools: Vec<spirv::Word>) -> spirv::Word {
413        // The combined `and`ed together of all of the bools up to this point.
414        let mut current_combined = bools.pop().unwrap();
415        for boolean in bools {
416            current_combined = self.write_logical_and(block, current_combined, boolean)
417        }
418        current_combined
419    }
420
421    // returns the id of the function, the function, and ids for its arguments.
422    fn write_function_signature(
423        &mut self,
424        arg_types: &[spirv::Word],
425        return_ty: spirv::Word,
426    ) -> (spirv::Word, Function, Vec<spirv::Word>) {
427        let func_ty = self.get_function_type(LookupFunctionType {
428            parameter_type_ids: Vec::from(arg_types),
429            return_type_id: return_ty,
430        });
431
432        let mut function = Function::default();
433        let func_id = self.id_gen.next();
434        function.signature = Some(Instruction::function(
435            return_ty,
436            func_id,
437            spirv::FunctionControl::empty(),
438            func_ty,
439        ));
440
441        let mut arg_ids = Vec::with_capacity(arg_types.len());
442
443        for (idx, &arg_ty) in arg_types.iter().enumerate() {
444            let id = self.id_gen.next();
445            let instruction = Instruction::function_parameter(arg_ty, id);
446            function.parameters.push(FunctionArgument {
447                instruction,
448                handle_id: idx as u32,
449            });
450            arg_ids.push(id);
451        }
452        (func_id, function, arg_ids)
453    }
454}