naga/back/spv/
ray.rs

1/*!
2Generating SPIR-V for ray query operations.
3*/
4
5use alloc::{vec, vec::Vec};
6
7use super::{
8    Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType,
9    Writer,
10};
11use crate::{arena::Handle, back::spv::LookupRayQueryFunction};
12
13/// helper function to check if a particular flag is set in a u32.
14fn write_ray_flags_contains_flags(
15    writer: &mut Writer,
16    block: &mut Block,
17    id: spirv::Word,
18    flag: u32,
19) -> spirv::Word {
20    let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag));
21    let zero_id = writer.get_constant_scalar(crate::Literal::U32(0));
22    let u32_type_id = writer.get_u32_type_id();
23    let bool_ty = writer.get_bool_type_id();
24
25    let and_id = writer.id_gen.next();
26    block.body.push(Instruction::binary(
27        spirv::Op::BitwiseAnd,
28        u32_type_id,
29        and_id,
30        id,
31        bit_id,
32    ));
33
34    let eq_id = writer.id_gen.next();
35    block.body.push(Instruction::binary(
36        spirv::Op::INotEqual,
37        bool_ty,
38        eq_id,
39        and_id,
40        zero_id,
41    ));
42
43    eq_id
44}
45
46impl Writer {
47    /// writes a logical and of two scalar booleans
48    fn write_logical_and(
49        &mut self,
50        block: &mut Block,
51        one: spirv::Word,
52        two: spirv::Word,
53    ) -> spirv::Word {
54        let id = self.id_gen.next();
55        let bool_id = self.get_bool_type_id();
56        block.body.push(Instruction::binary(
57            spirv::Op::LogicalAnd,
58            bool_id,
59            id,
60            one,
61            two,
62        ));
63        id
64    }
65
66    fn write_reduce_and(&mut self, block: &mut Block, mut bools: Vec<spirv::Word>) -> spirv::Word {
67        // The combined `and`ed together of all of the bools up to this point.
68        let mut current_combined = bools.pop().unwrap();
69        for boolean in bools {
70            current_combined = self.write_logical_and(block, current_combined, boolean)
71        }
72        current_combined
73    }
74
75    // returns the id of the function, the function, and ids for its arguments.
76    fn write_function_signature(
77        &mut self,
78        arg_types: &[spirv::Word],
79        return_ty: spirv::Word,
80    ) -> (spirv::Word, Function, Vec<spirv::Word>) {
81        let func_ty = self.get_function_type(LookupFunctionType {
82            parameter_type_ids: Vec::from(arg_types),
83            return_type_id: return_ty,
84        });
85
86        let mut function = Function::default();
87        let func_id = self.id_gen.next();
88        function.signature = Some(Instruction::function(
89            return_ty,
90            func_id,
91            spirv::FunctionControl::empty(),
92            func_ty,
93        ));
94
95        let mut arg_ids = Vec::with_capacity(arg_types.len());
96
97        for (idx, &arg_ty) in arg_types.iter().enumerate() {
98            let id = self.id_gen.next();
99            let instruction = Instruction::function_parameter(arg_ty, id);
100            function.parameters.push(FunctionArgument {
101                instruction,
102                handle_id: idx as u32,
103            });
104            arg_ids.push(id);
105        }
106        (func_id, function, arg_ids)
107    }
108
109    pub(super) fn write_ray_query_get_intersection_function(
110        &mut self,
111        is_committed: bool,
112        ir_module: &crate::Module,
113    ) -> spirv::Word {
114        if let Some(&word) =
115            self.ray_query_functions
116                .get(&LookupRayQueryFunction::GetIntersection {
117                    committed: is_committed,
118                })
119        {
120            return word;
121        }
122        let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
123        let intersection_type_id = self.get_handle_type_id(ray_intersection);
124        let intersection_pointer_type_id =
125            self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
126
127        let flag_type_id = self.get_u32_type_id();
128        let flag_pointer_type_id =
129            self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
130
131        let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
132            columns: crate::VectorSize::Quad,
133            rows: crate::VectorSize::Tri,
134            scalar: crate::Scalar::F32,
135        });
136        let transform_pointer_type_id =
137            self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
138
139        let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
140            size: crate::VectorSize::Bi,
141            scalar: crate::Scalar::F32,
142        });
143        let barycentrics_pointer_type_id =
144            self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
145
146        let bool_type_id = self.get_bool_type_id();
147        let bool_pointer_type_id =
148            self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
149
150        let scalar_type_id = self.get_f32_type_id();
151        let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
152
153        let argument_type_id = self.get_ray_query_pointer_id();
154
155        let (func_id, mut function, arg_ids) = self.write_function_signature(
156            &[argument_type_id, flag_pointer_type_id],
157            intersection_type_id,
158        );
159
160        let query_id = arg_ids[0];
161        let intersection_tracker_id = arg_ids[1];
162
163        let label_id = self.id_gen.next();
164        let mut block = Block::new(label_id);
165
166        let blank_intersection = self.get_constant_null(intersection_type_id);
167        let blank_intersection_id = self.id_gen.next();
168        // This must be before everything else in the function.
169        block.body.push(Instruction::variable(
170            intersection_pointer_type_id,
171            blank_intersection_id,
172            spirv::StorageClass::Function,
173            Some(blank_intersection),
174        ));
175
176        let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
177            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
178        } else {
179            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
180        } as _));
181
182        let loaded_ray_query_tracker_id = self.id_gen.next();
183        block.body.push(Instruction::load(
184            flag_type_id,
185            loaded_ray_query_tracker_id,
186            intersection_tracker_id,
187            None,
188        ));
189        let proceeded_id = write_ray_flags_contains_flags(
190            self,
191            &mut block,
192            loaded_ray_query_tracker_id,
193            super::RayQueryPoint::PROCEED.bits(),
194        );
195        let finished_proceed_id = write_ray_flags_contains_flags(
196            self,
197            &mut block,
198            loaded_ray_query_tracker_id,
199            super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
200        );
201        let proceed_finished_correct_id = if is_committed {
202            finished_proceed_id
203        } else {
204            let not_finished_id = self.id_gen.next();
205            block.body.push(Instruction::unary(
206                spirv::Op::LogicalNot,
207                bool_type_id,
208                not_finished_id,
209                finished_proceed_id,
210            ));
211            not_finished_id
212        };
213
214        let is_valid_id =
215            self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id);
216
217        let valid_id = self.id_gen.next();
218        let mut valid_block = Block::new(valid_id);
219
220        let final_label_id = self.id_gen.next();
221        let mut final_block = Block::new(final_label_id);
222
223        block.body.push(Instruction::selection_merge(
224            final_label_id,
225            spirv::SelectionControl::NONE,
226        ));
227        function.consume(
228            block,
229            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id),
230        );
231
232        let raw_kind_id = self.id_gen.next();
233        valid_block
234            .body
235            .push(Instruction::ray_query_get_intersection(
236                spirv::Op::RayQueryGetIntersectionTypeKHR,
237                flag_type_id,
238                raw_kind_id,
239                query_id,
240                intersection_id,
241            ));
242        let kind_id = if is_committed {
243            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
244            raw_kind_id
245        } else {
246            // Remap from the candidate kind to IR
247            let condition_id = self.id_gen.next();
248            let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
249                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
250                    as _,
251            ));
252            valid_block.body.push(Instruction::binary(
253                spirv::Op::IEqual,
254                self.get_bool_type_id(),
255                condition_id,
256                raw_kind_id,
257                committed_triangle_kind_id,
258            ));
259            let kind_id = self.id_gen.next();
260            valid_block.body.push(Instruction::select(
261                flag_type_id,
262                kind_id,
263                condition_id,
264                self.get_constant_scalar(crate::Literal::U32(
265                    crate::RayQueryIntersection::Triangle as _,
266                )),
267                self.get_constant_scalar(crate::Literal::U32(
268                    crate::RayQueryIntersection::Aabb as _,
269                )),
270            ));
271            kind_id
272        };
273        let idx_id = self.get_index_constant(0);
274        let access_idx = self.id_gen.next();
275        valid_block.body.push(Instruction::access_chain(
276            flag_pointer_type_id,
277            access_idx,
278            blank_intersection_id,
279            &[idx_id],
280        ));
281        valid_block
282            .body
283            .push(Instruction::store(access_idx, kind_id, None));
284
285        let not_none_comp_id = self.id_gen.next();
286        let none_id =
287            self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
288        valid_block.body.push(Instruction::binary(
289            spirv::Op::INotEqual,
290            self.get_bool_type_id(),
291            not_none_comp_id,
292            kind_id,
293            none_id,
294        ));
295
296        let not_none_label_id = self.id_gen.next();
297        let mut not_none_block = Block::new(not_none_label_id);
298
299        let outer_merge_label_id = self.id_gen.next();
300        let outer_merge_block = Block::new(outer_merge_label_id);
301
302        valid_block.body.push(Instruction::selection_merge(
303            outer_merge_label_id,
304            spirv::SelectionControl::NONE,
305        ));
306        function.consume(
307            valid_block,
308            Instruction::branch_conditional(
309                not_none_comp_id,
310                not_none_label_id,
311                outer_merge_label_id,
312            ),
313        );
314
315        let instance_custom_index_id = self.id_gen.next();
316        not_none_block
317            .body
318            .push(Instruction::ray_query_get_intersection(
319                spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
320                flag_type_id,
321                instance_custom_index_id,
322                query_id,
323                intersection_id,
324            ));
325        let instance_id = self.id_gen.next();
326        not_none_block
327            .body
328            .push(Instruction::ray_query_get_intersection(
329                spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
330                flag_type_id,
331                instance_id,
332                query_id,
333                intersection_id,
334            ));
335        let sbt_record_offset_id = self.id_gen.next();
336        not_none_block
337            .body
338            .push(Instruction::ray_query_get_intersection(
339                spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
340                flag_type_id,
341                sbt_record_offset_id,
342                query_id,
343                intersection_id,
344            ));
345        let geometry_index_id = self.id_gen.next();
346        not_none_block
347            .body
348            .push(Instruction::ray_query_get_intersection(
349                spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
350                flag_type_id,
351                geometry_index_id,
352                query_id,
353                intersection_id,
354            ));
355        let primitive_index_id = self.id_gen.next();
356        not_none_block
357            .body
358            .push(Instruction::ray_query_get_intersection(
359                spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
360                flag_type_id,
361                primitive_index_id,
362                query_id,
363                intersection_id,
364            ));
365
366        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
367        // but it's not a property of an intersection.
368
369        let object_to_world_id = self.id_gen.next();
370        not_none_block
371            .body
372            .push(Instruction::ray_query_get_intersection(
373                spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
374                transform_type_id,
375                object_to_world_id,
376                query_id,
377                intersection_id,
378            ));
379        let world_to_object_id = self.id_gen.next();
380        not_none_block
381            .body
382            .push(Instruction::ray_query_get_intersection(
383                spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
384                transform_type_id,
385                world_to_object_id,
386                query_id,
387                intersection_id,
388            ));
389
390        // instance custom index
391        let idx_id = self.get_index_constant(2);
392        let access_idx = self.id_gen.next();
393        not_none_block.body.push(Instruction::access_chain(
394            flag_pointer_type_id,
395            access_idx,
396            blank_intersection_id,
397            &[idx_id],
398        ));
399        not_none_block.body.push(Instruction::store(
400            access_idx,
401            instance_custom_index_id,
402            None,
403        ));
404
405        // instance
406        let idx_id = self.get_index_constant(3);
407        let access_idx = self.id_gen.next();
408        not_none_block.body.push(Instruction::access_chain(
409            flag_pointer_type_id,
410            access_idx,
411            blank_intersection_id,
412            &[idx_id],
413        ));
414        not_none_block
415            .body
416            .push(Instruction::store(access_idx, instance_id, None));
417
418        let idx_id = self.get_index_constant(4);
419        let access_idx = self.id_gen.next();
420        not_none_block.body.push(Instruction::access_chain(
421            flag_pointer_type_id,
422            access_idx,
423            blank_intersection_id,
424            &[idx_id],
425        ));
426        not_none_block
427            .body
428            .push(Instruction::store(access_idx, sbt_record_offset_id, None));
429
430        let idx_id = self.get_index_constant(5);
431        let access_idx = self.id_gen.next();
432        not_none_block.body.push(Instruction::access_chain(
433            flag_pointer_type_id,
434            access_idx,
435            blank_intersection_id,
436            &[idx_id],
437        ));
438        not_none_block
439            .body
440            .push(Instruction::store(access_idx, geometry_index_id, None));
441
442        let idx_id = self.get_index_constant(6);
443        let access_idx = self.id_gen.next();
444        not_none_block.body.push(Instruction::access_chain(
445            flag_pointer_type_id,
446            access_idx,
447            blank_intersection_id,
448            &[idx_id],
449        ));
450        not_none_block
451            .body
452            .push(Instruction::store(access_idx, primitive_index_id, None));
453
454        let idx_id = self.get_index_constant(9);
455        let access_idx = self.id_gen.next();
456        not_none_block.body.push(Instruction::access_chain(
457            transform_pointer_type_id,
458            access_idx,
459            blank_intersection_id,
460            &[idx_id],
461        ));
462        not_none_block
463            .body
464            .push(Instruction::store(access_idx, object_to_world_id, None));
465
466        let idx_id = self.get_index_constant(10);
467        let access_idx = self.id_gen.next();
468        not_none_block.body.push(Instruction::access_chain(
469            transform_pointer_type_id,
470            access_idx,
471            blank_intersection_id,
472            &[idx_id],
473        ));
474        not_none_block
475            .body
476            .push(Instruction::store(access_idx, world_to_object_id, None));
477
478        let tri_comp_id = self.id_gen.next();
479        let tri_id = self.get_constant_scalar(crate::Literal::U32(
480            crate::RayQueryIntersection::Triangle as _,
481        ));
482        not_none_block.body.push(Instruction::binary(
483            spirv::Op::IEqual,
484            self.get_bool_type_id(),
485            tri_comp_id,
486            kind_id,
487            tri_id,
488        ));
489
490        let tri_label_id = self.id_gen.next();
491        let mut tri_block = Block::new(tri_label_id);
492
493        let merge_label_id = self.id_gen.next();
494        let merge_block = Block::new(merge_label_id);
495        // t
496        {
497            let block = if is_committed {
498                &mut not_none_block
499            } else {
500                &mut tri_block
501            };
502            let t_id = self.id_gen.next();
503            block.body.push(Instruction::ray_query_get_intersection(
504                spirv::Op::RayQueryGetIntersectionTKHR,
505                scalar_type_id,
506                t_id,
507                query_id,
508                intersection_id,
509            ));
510            let idx_id = self.get_index_constant(1);
511            let access_idx = self.id_gen.next();
512            block.body.push(Instruction::access_chain(
513                float_pointer_type_id,
514                access_idx,
515                blank_intersection_id,
516                &[idx_id],
517            ));
518            block.body.push(Instruction::store(access_idx, t_id, None));
519        }
520        not_none_block.body.push(Instruction::selection_merge(
521            merge_label_id,
522            spirv::SelectionControl::NONE,
523        ));
524        function.consume(
525            not_none_block,
526            Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
527        );
528
529        let barycentrics_id = self.id_gen.next();
530        tri_block.body.push(Instruction::ray_query_get_intersection(
531            spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
532            barycentrics_type_id,
533            barycentrics_id,
534            query_id,
535            intersection_id,
536        ));
537
538        let front_face_id = self.id_gen.next();
539        tri_block.body.push(Instruction::ray_query_get_intersection(
540            spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
541            bool_type_id,
542            front_face_id,
543            query_id,
544            intersection_id,
545        ));
546
547        let idx_id = self.get_index_constant(7);
548        let access_idx = self.id_gen.next();
549        tri_block.body.push(Instruction::access_chain(
550            barycentrics_pointer_type_id,
551            access_idx,
552            blank_intersection_id,
553            &[idx_id],
554        ));
555        tri_block
556            .body
557            .push(Instruction::store(access_idx, barycentrics_id, None));
558
559        let idx_id = self.get_index_constant(8);
560        let access_idx = self.id_gen.next();
561        tri_block.body.push(Instruction::access_chain(
562            bool_pointer_type_id,
563            access_idx,
564            blank_intersection_id,
565            &[idx_id],
566        ));
567        tri_block
568            .body
569            .push(Instruction::store(access_idx, front_face_id, None));
570        function.consume(tri_block, Instruction::branch(merge_label_id));
571        function.consume(merge_block, Instruction::branch(outer_merge_label_id));
572        function.consume(outer_merge_block, Instruction::branch(final_label_id));
573
574        let loaded_blank_intersection_id = self.id_gen.next();
575        final_block.body.push(Instruction::load(
576            intersection_type_id,
577            loaded_blank_intersection_id,
578            blank_intersection_id,
579            None,
580        ));
581        function.consume(
582            final_block,
583            Instruction::return_value(loaded_blank_intersection_id),
584        );
585
586        function.to_words(&mut self.logical_layout.function_definitions);
587        self.ray_query_functions.insert(
588            LookupRayQueryFunction::GetIntersection {
589                committed: is_committed,
590            },
591            func_id,
592        );
593        func_id
594    }
595
596    fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word {
597        if let Some(&word) = self
598            .ray_query_functions
599            .get(&LookupRayQueryFunction::Initialize)
600        {
601            return word;
602        }
603
604        let ray_query_type_id = self.get_ray_query_pointer_id();
605        let acceleration_structure_type_id =
606            self.get_localtype_id(super::LocalType::AccelerationStructure);
607        let ray_desc_type_id = self.get_handle_type_id(
608            ir_module
609                .special_types
610                .ray_desc
611                .expect("ray desc should be set if ray queries are being initialized"),
612        );
613
614        let u32_ty = self.get_u32_type_id();
615        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
616
617        let f32_type_id = self.get_f32_type_id();
618        let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
619
620        let bool_type_id = self.get_bool_type_id();
621        let bool_vec3_type_id = self.get_vec3_bool_type_id();
622
623        let (func_id, mut function, arg_ids) = self.write_function_signature(
624            &[
625                ray_query_type_id,
626                acceleration_structure_type_id,
627                ray_desc_type_id,
628                u32_ptr_ty,
629                f32_ptr_ty,
630            ],
631            self.void_type,
632        );
633
634        let query_id = arg_ids[0];
635        let acceleration_structure_id = arg_ids[1];
636        let desc_id = arg_ids[2];
637        let init_tracker_id = arg_ids[3];
638        let t_max_tracker_id = arg_ids[4];
639
640        let label_id = self.id_gen.next();
641        let mut block = Block::new(label_id);
642
643        let flag_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
644
645        //Note: composite extract indices and types must match `generate_ray_desc_type`
646        let ray_flags_id = self.id_gen.next();
647        block.body.push(Instruction::composite_extract(
648            flag_type_id,
649            ray_flags_id,
650            desc_id,
651            &[0],
652        ));
653        let cull_mask_id = self.id_gen.next();
654        block.body.push(Instruction::composite_extract(
655            flag_type_id,
656            cull_mask_id,
657            desc_id,
658            &[1],
659        ));
660
661        let tmin_id = self.id_gen.next();
662        block.body.push(Instruction::composite_extract(
663            f32_type_id,
664            tmin_id,
665            desc_id,
666            &[2],
667        ));
668        let tmax_id = self.id_gen.next();
669        block.body.push(Instruction::composite_extract(
670            f32_type_id,
671            tmax_id,
672            desc_id,
673            &[3],
674        ));
675        block
676            .body
677            .push(Instruction::store(t_max_tracker_id, tmax_id, None));
678
679        let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
680            size: crate::VectorSize::Tri,
681            scalar: crate::Scalar::F32,
682        });
683        let ray_origin_id = self.id_gen.next();
684        block.body.push(Instruction::composite_extract(
685            vector_type_id,
686            ray_origin_id,
687            desc_id,
688            &[4],
689        ));
690        let ray_dir_id = self.id_gen.next();
691        block.body.push(Instruction::composite_extract(
692            vector_type_id,
693            ray_dir_id,
694            desc_id,
695            &[5],
696        ));
697
698        let valid_id = self.ray_query_initialization_tracking.then(||{
699            let tmin_le_tmax_id = self.id_gen.next();
700            // Check both that tmin is less than or equal to tmax (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06350)
701            // and implicitly that neither tmin or tmax are NaN (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06351)
702            // because this checks if tmin and tmax are ordered too (i.e: not NaN).
703            block.body.push(Instruction::binary(
704                spirv::Op::FOrdLessThanEqual,
705                bool_type_id,
706                tmin_le_tmax_id,
707                tmin_id,
708                tmax_id,
709            ));
710
711            // Check that tmin is greater than or equal to 0 (and
712            // therefore also tmax is too because it is greater than
713            // or equal to tmin) (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06349).
714            let tmin_ge_zero_id = self.id_gen.next();
715            let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0));
716            block.body.push(Instruction::binary(
717                spirv::Op::FOrdGreaterThanEqual,
718                bool_type_id,
719                tmin_ge_zero_id,
720                tmin_id,
721                zero_id,
722            ));
723
724            // Check that ray origin is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
725            let ray_origin_infinite_id = self.id_gen.next();
726            block.body.push(Instruction::unary(
727                spirv::Op::IsInf,
728                bool_vec3_type_id,
729                ray_origin_infinite_id,
730                ray_origin_id,
731            ));
732            let any_ray_origin_infinite_id = self.id_gen.next();
733            block.body.push(Instruction::unary(
734                spirv::Op::Any,
735                bool_type_id,
736                any_ray_origin_infinite_id,
737                ray_origin_infinite_id,
738            ));
739
740            let ray_origin_nan_id = self.id_gen.next();
741            block.body.push(Instruction::unary(
742                spirv::Op::IsNan,
743                bool_vec3_type_id,
744                ray_origin_nan_id,
745                ray_origin_id,
746            ));
747            let any_ray_origin_nan_id = self.id_gen.next();
748            block.body.push(Instruction::unary(
749                spirv::Op::Any,
750                bool_type_id,
751                any_ray_origin_nan_id,
752                ray_origin_nan_id,
753            ));
754
755            let ray_origin_not_finite_id = self.id_gen.next();
756            block.body.push(Instruction::binary(
757                spirv::Op::LogicalOr,
758                bool_type_id,
759                ray_origin_not_finite_id,
760                any_ray_origin_nan_id,
761                any_ray_origin_infinite_id,
762            ));
763
764            let all_ray_origin_finite_id = self.id_gen.next();
765            block.body.push(Instruction::unary(
766                spirv::Op::LogicalNot,
767                bool_type_id,
768                all_ray_origin_finite_id,
769                ray_origin_not_finite_id,
770            ));
771
772            // Check that ray direction is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
773            let ray_dir_infinite_id = self.id_gen.next();
774            block.body.push(Instruction::unary(
775                spirv::Op::IsInf,
776                bool_vec3_type_id,
777                ray_dir_infinite_id,
778                ray_dir_id,
779            ));
780            let any_ray_dir_infinite_id = self.id_gen.next();
781            block.body.push(Instruction::unary(
782                spirv::Op::Any,
783                bool_type_id,
784                any_ray_dir_infinite_id,
785                ray_dir_infinite_id,
786            ));
787
788            let ray_dir_nan_id = self.id_gen.next();
789            block.body.push(Instruction::unary(
790                spirv::Op::IsNan,
791                bool_vec3_type_id,
792                ray_dir_nan_id,
793                ray_dir_id,
794            ));
795            let any_ray_dir_nan_id = self.id_gen.next();
796            block.body.push(Instruction::unary(
797                spirv::Op::Any,
798                bool_type_id,
799                any_ray_dir_nan_id,
800                ray_dir_nan_id,
801            ));
802
803            let ray_dir_not_finite_id = self.id_gen.next();
804            block.body.push(Instruction::binary(
805                spirv::Op::LogicalOr,
806                bool_type_id,
807                ray_dir_not_finite_id,
808                any_ray_dir_nan_id,
809                any_ray_dir_infinite_id,
810            ));
811
812            let all_ray_dir_finite_id = self.id_gen.next();
813            block.body.push(Instruction::unary(
814                spirv::Op::LogicalNot,
815                bool_type_id,
816                all_ray_dir_finite_id,
817                ray_dir_not_finite_id,
818            ));
819
820            /// Writes spirv to check that less than two booleans are true
821            ///
822            /// For each boolean: removes it, `and`s it with all others (i.e for all possible combinations of two booleans in the list checks to see if both are true).
823            /// Then `or`s all of these checks together. This produces whether two or more booleans are true.
824            fn write_less_than_2_true(
825                writer: &mut Writer,
826                block: &mut Block,
827                mut bools: Vec<spirv::Word>,
828            ) -> spirv::Word {
829                assert!(bools.len() > 1, "Must have multiple booleans!");
830                let bool_ty = writer.get_bool_type_id();
831                let mut each_two_true = Vec::new();
832                while let Some(last_bool) = bools.pop() {
833                    for &bool in &bools {
834                        let both_true_id = writer.write_logical_and(
835                            block,
836                            last_bool,
837                            bool,
838                        );
839                        each_two_true.push(both_true_id);
840                    }
841                }
842                let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true`");
843                for two_true in each_two_true {
844                    let new_all_or_id = writer.id_gen.next();
845                    block.body.push(Instruction::binary(
846                        spirv::Op::LogicalOr,
847                        bool_ty,
848                        new_all_or_id,
849                        all_or_id,
850                        two_true,
851                    ));
852                    all_or_id = new_all_or_id;
853                }
854
855                let less_than_two_id = writer.id_gen.next();
856                block.body.push(Instruction::unary(
857                    spirv::Op::LogicalNot,
858                    bool_ty,
859                    less_than_two_id,
860                    all_or_id,
861                ));
862                less_than_two_id
863            }
864
865            // Check that at most one of skip triangles and skip AABBs is
866            // present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06889)
867            let contains_skip_triangles = write_ray_flags_contains_flags(
868                self,
869                &mut block,
870                ray_flags_id,
871                crate::RayFlag::SKIP_TRIANGLES.bits(),
872            );
873            let contains_skip_aabbs = write_ray_flags_contains_flags(
874                self,
875                &mut block,
876                ray_flags_id,
877                crate::RayFlag::SKIP_AABBS.bits(),
878            );
879
880            let not_contain_skip_triangles_aabbs = write_less_than_2_true(
881                self,
882                &mut block,
883                vec![contains_skip_triangles, contains_skip_aabbs],
884            );
885
886            // Check that at most one of skip triangles (taken from above check),
887            // cull back facing, and cull front face is present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06890)
888            let contains_cull_back = write_ray_flags_contains_flags(
889                self,
890                &mut block,
891                ray_flags_id,
892                crate::RayFlag::CULL_BACK_FACING.bits(),
893            );
894            let contains_cull_front = write_ray_flags_contains_flags(
895                self,
896                &mut block,
897                ray_flags_id,
898                crate::RayFlag::CULL_FRONT_FACING.bits(),
899            );
900
901            let not_contain_skip_triangles_cull = write_less_than_2_true(
902                self,
903                &mut block,
904                vec![
905                    contains_skip_triangles,
906                    contains_cull_back,
907                    contains_cull_front,
908                ],
909            );
910
911            // Check that at most one of force opaque, force not opaque, cull opaque,
912            // and cull not opaque are present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06891)
913            let contains_opaque = write_ray_flags_contains_flags(
914                self,
915                &mut block,
916                ray_flags_id,
917                crate::RayFlag::FORCE_OPAQUE.bits(),
918            );
919            let contains_no_opaque = write_ray_flags_contains_flags(
920                self,
921                &mut block,
922                ray_flags_id,
923                crate::RayFlag::FORCE_NO_OPAQUE.bits(),
924            );
925            let contains_cull_opaque = write_ray_flags_contains_flags(
926                self,
927                &mut block,
928                ray_flags_id,
929                crate::RayFlag::CULL_OPAQUE.bits(),
930            );
931            let contains_cull_no_opaque = write_ray_flags_contains_flags(
932                self,
933                &mut block,
934                ray_flags_id,
935                crate::RayFlag::CULL_NO_OPAQUE.bits(),
936            );
937
938            let not_contain_multiple_opaque = write_less_than_2_true(
939                self,
940                &mut block,
941                vec![
942                    contains_opaque,
943                    contains_no_opaque,
944                    contains_cull_opaque,
945                    contains_cull_no_opaque,
946                ],
947            );
948
949            // Combine all checks into a single flag saying whether the call is valid or not.
950            self.write_reduce_and(
951                &mut block,
952                vec![
953                    tmin_le_tmax_id,
954                    tmin_ge_zero_id,
955                    all_ray_origin_finite_id,
956                    all_ray_dir_finite_id,
957                    not_contain_skip_triangles_aabbs,
958                    not_contain_skip_triangles_cull,
959                    not_contain_multiple_opaque,
960                ],
961            )
962        });
963
964        let merge_label_id = self.id_gen.next();
965        let merge_block = Block::new(merge_label_id);
966
967        // NOTE: this block will be unreachable if initialization tracking is disabled.
968        let invalid_label_id = self.id_gen.next();
969        let mut invalid_block = Block::new(invalid_label_id);
970
971        let valid_label_id = self.id_gen.next();
972        let mut valid_block = Block::new(valid_label_id);
973
974        match valid_id {
975            Some(all_valid_id) => {
976                block.body.push(Instruction::selection_merge(
977                    merge_label_id,
978                    spirv::SelectionControl::NONE,
979                ));
980                function.consume(
981                    block,
982                    Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
983                );
984            }
985            None => {
986                function.consume(block, Instruction::branch(valid_label_id));
987            }
988        }
989
990        valid_block.body.push(Instruction::ray_query_initialize(
991            query_id,
992            acceleration_structure_id,
993            ray_flags_id,
994            cull_mask_id,
995            ray_origin_id,
996            tmin_id,
997            ray_dir_id,
998            tmax_id,
999        ));
1000
1001        let const_initialized = self.get_constant_scalar(crate::Literal::U32(
1002            super::RayQueryPoint::INITIALIZED.bits(),
1003        ));
1004        valid_block
1005            .body
1006            .push(Instruction::store(init_tracker_id, const_initialized, None));
1007
1008        function.consume(valid_block, Instruction::branch(merge_label_id));
1009
1010        if self
1011            .flags
1012            .contains(super::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
1013        {
1014            self.write_debug_printf(
1015                &mut invalid_block,
1016                "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
1017                &[
1018                    ray_flags_id,
1019                    tmin_id,
1020                    tmax_id,
1021                    ray_origin_id,
1022                    ray_dir_id,
1023                ],
1024            );
1025        }
1026
1027        function.consume(invalid_block, Instruction::branch(merge_label_id));
1028
1029        function.consume(merge_block, Instruction::return_void());
1030
1031        function.to_words(&mut self.logical_layout.function_definitions);
1032
1033        self.ray_query_functions
1034            .insert(LookupRayQueryFunction::Initialize, func_id);
1035        func_id
1036    }
1037
1038    fn write_ray_query_proceed(&mut self) -> spirv::Word {
1039        if let Some(&word) = self
1040            .ray_query_functions
1041            .get(&LookupRayQueryFunction::Proceed)
1042        {
1043            return word;
1044        }
1045
1046        let ray_query_type_id = self.get_ray_query_pointer_id();
1047
1048        let u32_ty = self.get_u32_type_id();
1049        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1050
1051        let bool_type_id = self.get_bool_type_id();
1052        let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
1053
1054        let (func_id, mut function, arg_ids) =
1055            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id);
1056
1057        let query_id = arg_ids[0];
1058        let init_tracker_id = arg_ids[1];
1059
1060        let block_id = self.id_gen.next();
1061        let mut block = Block::new(block_id);
1062
1063        // TODO: perhaps this could be replaced with an OpPhi?
1064        let proceeded_id = self.id_gen.next();
1065        let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
1066        block.body.push(Instruction::variable(
1067            bool_ptr_ty,
1068            proceeded_id,
1069            spirv::StorageClass::Function,
1070            Some(const_false),
1071        ));
1072
1073        let initialized_tracker_id = self.id_gen.next();
1074        block.body.push(Instruction::load(
1075            u32_ty,
1076            initialized_tracker_id,
1077            init_tracker_id,
1078            None,
1079        ));
1080
1081        let merge_id = self.id_gen.next();
1082        let mut merge_block = Block::new(merge_id);
1083
1084        let valid_block_id = self.id_gen.next();
1085        let mut valid_block = Block::new(valid_block_id);
1086
1087        let instruction = if self.ray_query_initialization_tracking {
1088            let is_initialized = write_ray_flags_contains_flags(
1089                self,
1090                &mut block,
1091                initialized_tracker_id,
1092                super::RayQueryPoint::INITIALIZED.bits(),
1093            );
1094
1095            block.body.push(Instruction::selection_merge(
1096                merge_id,
1097                spirv::SelectionControl::NONE,
1098            ));
1099
1100            Instruction::branch_conditional(is_initialized, valid_block_id, merge_id)
1101        } else {
1102            Instruction::branch(valid_block_id)
1103        };
1104
1105        function.consume(block, instruction);
1106
1107        let has_proceeded = self.id_gen.next();
1108        valid_block.body.push(Instruction::ray_query_proceed(
1109            bool_type_id,
1110            has_proceeded,
1111            query_id,
1112        ));
1113
1114        valid_block
1115            .body
1116            .push(Instruction::store(proceeded_id, has_proceeded, None));
1117
1118        let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
1119            (super::RayQueryPoint::PROCEED | super::RayQueryPoint::FINISHED_TRAVERSAL).bits(),
1120        ));
1121        let add_flag_continuing =
1122            self.get_constant_scalar(crate::Literal::U32(super::RayQueryPoint::PROCEED.bits()));
1123
1124        let add_flags_id = self.id_gen.next();
1125        valid_block.body.push(Instruction::select(
1126            u32_ty,
1127            add_flags_id,
1128            has_proceeded,
1129            add_flag_continuing,
1130            add_flag_finished,
1131        ));
1132        let final_flags = self.id_gen.next();
1133        valid_block.body.push(Instruction::binary(
1134            spirv::Op::BitwiseOr,
1135            u32_ty,
1136            final_flags,
1137            initialized_tracker_id,
1138            add_flags_id,
1139        ));
1140        valid_block
1141            .body
1142            .push(Instruction::store(init_tracker_id, final_flags, None));
1143
1144        function.consume(valid_block, Instruction::branch(merge_id));
1145
1146        let loaded_proceeded_id = self.id_gen.next();
1147        merge_block.body.push(Instruction::load(
1148            bool_type_id,
1149            loaded_proceeded_id,
1150            proceeded_id,
1151            None,
1152        ));
1153
1154        function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));
1155
1156        function.to_words(&mut self.logical_layout.function_definitions);
1157
1158        self.ray_query_functions
1159            .insert(LookupRayQueryFunction::Proceed, func_id);
1160        func_id
1161    }
1162
1163    fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
1164        if let Some(&word) = self
1165            .ray_query_functions
1166            .get(&LookupRayQueryFunction::GenerateIntersection)
1167        {
1168            return word;
1169        }
1170
1171        let ray_query_type_id = self.get_ray_query_pointer_id();
1172
1173        let u32_ty = self.get_u32_type_id();
1174        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1175
1176        let f32_type_id = self.get_f32_type_id();
1177        let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1178
1179        let bool_type_id = self.get_bool_type_id();
1180
1181        let (func_id, mut function, arg_ids) = self.write_function_signature(
1182            &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id],
1183            self.void_type,
1184        );
1185
1186        let query_id = arg_ids[0];
1187        let init_tracker_id = arg_ids[1];
1188        let depth_id = arg_ids[2];
1189        let t_max_tracker_id = arg_ids[3];
1190
1191        let block_id = self.id_gen.next();
1192        let mut block = Block::new(block_id);
1193
1194        let current_t = self.id_gen.next();
1195        block.body.push(Instruction::variable(
1196            f32_ptr_type_id,
1197            current_t,
1198            spirv::StorageClass::Function,
1199            None,
1200        ));
1201
1202        let current_t = self.id_gen.next();
1203        block.body.push(Instruction::variable(
1204            f32_ptr_type_id,
1205            current_t,
1206            spirv::StorageClass::Function,
1207            None,
1208        ));
1209
1210        let valid_id = self.id_gen.next();
1211        let mut valid_block = Block::new(valid_id);
1212
1213        let final_label_id = self.id_gen.next();
1214        let final_block = Block::new(final_label_id);
1215
1216        let instruction = if self.ray_query_initialization_tracking {
1217            let initialized_tracker_id = self.id_gen.next();
1218            block.body.push(Instruction::load(
1219                u32_ty,
1220                initialized_tracker_id,
1221                init_tracker_id,
1222                None,
1223            ));
1224
1225            let proceeded_id = write_ray_flags_contains_flags(
1226                self,
1227                &mut block,
1228                initialized_tracker_id,
1229                super::RayQueryPoint::PROCEED.bits(),
1230            );
1231            let finished_proceed_id = write_ray_flags_contains_flags(
1232                self,
1233                &mut block,
1234                initialized_tracker_id,
1235                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1236            );
1237
1238            // Can't find anything to suggest double calling this function is invalid.
1239
1240            let not_finished_id = self.id_gen.next();
1241            block.body.push(Instruction::unary(
1242                spirv::Op::LogicalNot,
1243                bool_type_id,
1244                not_finished_id,
1245                finished_proceed_id,
1246            ));
1247
1248            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1249
1250            block.body.push(Instruction::selection_merge(
1251                final_label_id,
1252                spirv::SelectionControl::NONE,
1253            ));
1254
1255            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1256        } else {
1257            Instruction::branch(valid_id)
1258        };
1259
1260        function.consume(block, instruction);
1261
1262        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1263            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1264        ));
1265        let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32(
1266            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
1267        ));
1268        let raw_kind_id = self.id_gen.next();
1269        valid_block
1270            .body
1271            .push(Instruction::ray_query_get_intersection(
1272                spirv::Op::RayQueryGetIntersectionTypeKHR,
1273                u32_ty,
1274                raw_kind_id,
1275                query_id,
1276                intersection_id,
1277            ));
1278
1279        let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
1280            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
1281        ));
1282        let intersection_aabb_id = self.id_gen.next();
1283        valid_block.body.push(Instruction::binary(
1284            spirv::Op::IEqual,
1285            bool_type_id,
1286            intersection_aabb_id,
1287            raw_kind_id,
1288            candidate_aabb_id,
1289        ));
1290
1291        // Check that the provided t value is between t min and the current committed
1292        // t value, (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryGenerateIntersectionKHR-06353)
1293
1294        // Get the tmin
1295        let t_min_id = self.id_gen.next();
1296        valid_block.body.push(Instruction::ray_query_get_t_min(
1297            f32_type_id,
1298            t_min_id,
1299            query_id,
1300        ));
1301
1302        // Get the current committed t, or tmax if no hit.
1303        // Basically emulate HLSL's (easier) version
1304        // Pseudo-code:
1305        // ````wgsl
1306        // // start of function
1307        // var current_t:f32;
1308        // ...
1309        // let committed_type_id = RayQueryGetIntersectionTypeKHR<Committed>(query_id);
1310        // if committed_type_id == Committed_None {
1311        //     current_t = load(t_max_tracker);
1312        // } else {
1313        //     current_t = RayQueryGetIntersectionTKHR<Committed>(query_id);
1314        // }
1315        // ...
1316        // ````
1317
1318        let committed_type_id = self.id_gen.next();
1319        valid_block
1320            .body
1321            .push(Instruction::ray_query_get_intersection(
1322                spirv::Op::RayQueryGetIntersectionTypeKHR,
1323                u32_ty,
1324                committed_type_id,
1325                query_id,
1326                committed_intersection_id,
1327            ));
1328
1329        let no_committed = self.id_gen.next();
1330        valid_block.body.push(Instruction::binary(
1331            spirv::Op::IEqual,
1332            bool_type_id,
1333            no_committed,
1334            committed_type_id,
1335            self.get_constant_scalar(crate::Literal::U32(
1336                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _,
1337            )),
1338        ));
1339
1340        let next_valid_block_id = self.id_gen.next();
1341        let no_committed_block_id = self.id_gen.next();
1342        let mut no_committed_block = Block::new(no_committed_block_id);
1343        let committed_block_id = self.id_gen.next();
1344        let mut committed_block = Block::new(committed_block_id);
1345        valid_block.body.push(Instruction::selection_merge(
1346            next_valid_block_id,
1347            spirv::SelectionControl::NONE,
1348        ));
1349        function.consume(
1350            valid_block,
1351            Instruction::branch_conditional(
1352                no_committed,
1353                no_committed_block_id,
1354                committed_block_id,
1355            ),
1356        );
1357
1358        // Assign t_max to current_t
1359        let t_max_id = self.id_gen.next();
1360        no_committed_block.body.push(Instruction::load(
1361            f32_type_id,
1362            t_max_id,
1363            t_max_tracker_id,
1364            None,
1365        ));
1366        no_committed_block
1367            .body
1368            .push(Instruction::store(current_t, t_max_id, None));
1369        function.consume(no_committed_block, Instruction::branch(next_valid_block_id));
1370
1371        // Assign t_current to current_t
1372        let latest_t_id = self.id_gen.next();
1373        committed_block
1374            .body
1375            .push(Instruction::ray_query_get_intersection(
1376                spirv::Op::RayQueryGetIntersectionTKHR,
1377                f32_type_id,
1378                latest_t_id,
1379                query_id,
1380                intersection_id,
1381            ));
1382        committed_block
1383            .body
1384            .push(Instruction::store(current_t, latest_t_id, None));
1385        function.consume(committed_block, Instruction::branch(next_valid_block_id));
1386
1387        let mut valid_block = Block::new(next_valid_block_id);
1388
1389        let t_ge_t_min = self.id_gen.next();
1390        valid_block.body.push(Instruction::binary(
1391            spirv::Op::FOrdGreaterThanEqual,
1392            bool_type_id,
1393            t_ge_t_min,
1394            depth_id,
1395            t_min_id,
1396        ));
1397        let t_current = self.id_gen.next();
1398        valid_block
1399            .body
1400            .push(Instruction::load(f32_type_id, t_current, current_t, None));
1401        let t_le_t_current = self.id_gen.next();
1402        valid_block.body.push(Instruction::binary(
1403            spirv::Op::FOrdLessThanEqual,
1404            bool_type_id,
1405            t_le_t_current,
1406            depth_id,
1407            t_current,
1408        ));
1409
1410        let t_in_range = self.id_gen.next();
1411        valid_block.body.push(Instruction::binary(
1412            spirv::Op::LogicalAnd,
1413            bool_type_id,
1414            t_in_range,
1415            t_ge_t_min,
1416            t_le_t_current,
1417        ));
1418
1419        let call_valid_id = self.id_gen.next();
1420        valid_block.body.push(Instruction::binary(
1421            spirv::Op::LogicalAnd,
1422            bool_type_id,
1423            call_valid_id,
1424            t_in_range,
1425            intersection_aabb_id,
1426        ));
1427
1428        let generate_label_id = self.id_gen.next();
1429        let mut generate_block = Block::new(generate_label_id);
1430
1431        let merge_label_id = self.id_gen.next();
1432        let merge_block = Block::new(merge_label_id);
1433
1434        valid_block.body.push(Instruction::selection_merge(
1435            merge_label_id,
1436            spirv::SelectionControl::NONE,
1437        ));
1438        function.consume(
1439            valid_block,
1440            Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id),
1441        );
1442
1443        generate_block
1444            .body
1445            .push(Instruction::ray_query_generate_intersection(
1446                query_id, depth_id,
1447            ));
1448
1449        function.consume(generate_block, Instruction::branch(merge_label_id));
1450        function.consume(merge_block, Instruction::branch(final_label_id));
1451
1452        function.consume(final_block, Instruction::return_void());
1453
1454        function.to_words(&mut self.logical_layout.function_definitions);
1455
1456        self.ray_query_functions
1457            .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
1458        func_id
1459    }
1460
1461    fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
1462        if let Some(&word) = self
1463            .ray_query_functions
1464            .get(&LookupRayQueryFunction::ConfirmIntersection)
1465        {
1466            return word;
1467        }
1468
1469        let ray_query_type_id = self.get_ray_query_pointer_id();
1470
1471        let u32_ty = self.get_u32_type_id();
1472        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1473
1474        let bool_type_id = self.get_bool_type_id();
1475
1476        let (func_id, mut function, arg_ids) =
1477            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1478
1479        let query_id = arg_ids[0];
1480        let init_tracker_id = arg_ids[1];
1481
1482        let block_id = self.id_gen.next();
1483        let mut block = Block::new(block_id);
1484
1485        let valid_id = self.id_gen.next();
1486        let mut valid_block = Block::new(valid_id);
1487
1488        let final_label_id = self.id_gen.next();
1489        let final_block = Block::new(final_label_id);
1490
1491        let instruction = if self.ray_query_initialization_tracking {
1492            let initialized_tracker_id = self.id_gen.next();
1493            block.body.push(Instruction::load(
1494                u32_ty,
1495                initialized_tracker_id,
1496                init_tracker_id,
1497                None,
1498            ));
1499
1500            let proceeded_id = write_ray_flags_contains_flags(
1501                self,
1502                &mut block,
1503                initialized_tracker_id,
1504                super::RayQueryPoint::PROCEED.bits(),
1505            );
1506            let finished_proceed_id = write_ray_flags_contains_flags(
1507                self,
1508                &mut block,
1509                initialized_tracker_id,
1510                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1511            );
1512            // Although it seems strange to call this twice, I (Vecvec) can't find anything to suggest double calling this function is invalid.
1513            let not_finished_id = self.id_gen.next();
1514            block.body.push(Instruction::unary(
1515                spirv::Op::LogicalNot,
1516                bool_type_id,
1517                not_finished_id,
1518                finished_proceed_id,
1519            ));
1520
1521            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1522
1523            block.body.push(Instruction::selection_merge(
1524                final_label_id,
1525                spirv::SelectionControl::NONE,
1526            ));
1527
1528            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1529        } else {
1530            Instruction::branch(valid_id)
1531        };
1532
1533        function.consume(block, instruction);
1534
1535        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1536            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1537        ));
1538        let raw_kind_id = self.id_gen.next();
1539        valid_block
1540            .body
1541            .push(Instruction::ray_query_get_intersection(
1542                spirv::Op::RayQueryGetIntersectionTypeKHR,
1543                u32_ty,
1544                raw_kind_id,
1545                query_id,
1546                intersection_id,
1547            ));
1548
1549        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
1550            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
1551        ));
1552        let intersection_tri_id = self.id_gen.next();
1553        valid_block.body.push(Instruction::binary(
1554            spirv::Op::IEqual,
1555            bool_type_id,
1556            intersection_tri_id,
1557            raw_kind_id,
1558            candidate_tri_id,
1559        ));
1560
1561        let generate_label_id = self.id_gen.next();
1562        let mut generate_block = Block::new(generate_label_id);
1563
1564        let merge_label_id = self.id_gen.next();
1565        let merge_block = Block::new(merge_label_id);
1566
1567        valid_block.body.push(Instruction::selection_merge(
1568            merge_label_id,
1569            spirv::SelectionControl::NONE,
1570        ));
1571        function.consume(
1572            valid_block,
1573            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1574        );
1575
1576        generate_block
1577            .body
1578            .push(Instruction::ray_query_confirm_intersection(query_id));
1579
1580        function.consume(generate_block, Instruction::branch(merge_label_id));
1581        function.consume(merge_block, Instruction::branch(final_label_id));
1582
1583        function.consume(final_block, Instruction::return_void());
1584
1585        self.ray_query_functions
1586            .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);
1587
1588        function.to_words(&mut self.logical_layout.function_definitions);
1589
1590        func_id
1591    }
1592
1593    fn write_ray_query_get_vertex_positions(
1594        &mut self,
1595        is_committed: bool,
1596        ir_module: &crate::Module,
1597    ) -> spirv::Word {
1598        if let Some(&word) =
1599            self.ray_query_functions
1600                .get(&LookupRayQueryFunction::GetVertexPositions {
1601                    committed: is_committed,
1602                })
1603        {
1604            return word;
1605        }
1606
1607        let (committed_ty, committed_tri_ty) = if is_committed {
1608            (
1609                spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
1610                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
1611                    as u32,
1612            )
1613        } else {
1614            (
1615                spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
1616                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
1617                    as u32,
1618            )
1619        };
1620
1621        let ray_query_type_id = self.get_ray_query_pointer_id();
1622
1623        let u32_ty = self.get_u32_type_id();
1624        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1625
1626        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1627            *ir_module
1628                .special_types
1629                .ray_vertex_return
1630                .as_ref()
1631                .expect("must be generated when reading in get vertex position"),
1632        );
1633        let ptr_return_ty =
1634            self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);
1635
1636        let bool_type_id = self.get_bool_type_id();
1637
1638        let (func_id, mut function, arg_ids) = self.write_function_signature(
1639            &[ray_query_type_id, u32_ptr_ty],
1640            rq_get_vertex_positions_ty_id,
1641        );
1642
1643        let query_id = arg_ids[0];
1644        let init_tracker_id = arg_ids[1];
1645
1646        let block_id = self.id_gen.next();
1647        let mut block = Block::new(block_id);
1648
1649        let return_id = self.id_gen.next();
1650        block.body.push(Instruction::variable(
1651            ptr_return_ty,
1652            return_id,
1653            spirv::StorageClass::Function,
1654            Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
1655        ));
1656
1657        let valid_id = self.id_gen.next();
1658        let mut valid_block = Block::new(valid_id);
1659
1660        let final_label_id = self.id_gen.next();
1661        let mut final_block = Block::new(final_label_id);
1662
1663        let instruction = if self.ray_query_initialization_tracking {
1664            let initialized_tracker_id = self.id_gen.next();
1665            block.body.push(Instruction::load(
1666                u32_ty,
1667                initialized_tracker_id,
1668                init_tracker_id,
1669                None,
1670            ));
1671
1672            let proceeded_id = write_ray_flags_contains_flags(
1673                self,
1674                &mut block,
1675                initialized_tracker_id,
1676                super::RayQueryPoint::PROCEED.bits(),
1677            );
1678            let finished_proceed_id = write_ray_flags_contains_flags(
1679                self,
1680                &mut block,
1681                initialized_tracker_id,
1682                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1683            );
1684
1685            let correct_finish_id = if is_committed {
1686                finished_proceed_id
1687            } else {
1688                let not_finished_id = self.id_gen.next();
1689                block.body.push(Instruction::unary(
1690                    spirv::Op::LogicalNot,
1691                    bool_type_id,
1692                    not_finished_id,
1693                    finished_proceed_id,
1694                ));
1695                not_finished_id
1696            };
1697
1698            let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id);
1699            block.body.push(Instruction::selection_merge(
1700                final_label_id,
1701                spirv::SelectionControl::NONE,
1702            ));
1703            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1704        } else {
1705            Instruction::branch(valid_id)
1706        };
1707
1708        function.consume(block, instruction);
1709
1710        let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
1711        let raw_kind_id = self.id_gen.next();
1712        valid_block
1713            .body
1714            .push(Instruction::ray_query_get_intersection(
1715                spirv::Op::RayQueryGetIntersectionTypeKHR,
1716                u32_ty,
1717                raw_kind_id,
1718                query_id,
1719                intersection_id,
1720            ));
1721
1722        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
1723        let intersection_tri_id = self.id_gen.next();
1724        valid_block.body.push(Instruction::binary(
1725            spirv::Op::IEqual,
1726            bool_type_id,
1727            intersection_tri_id,
1728            raw_kind_id,
1729            candidate_tri_id,
1730        ));
1731
1732        let generate_label_id = self.id_gen.next();
1733        let mut vertex_return_block = Block::new(generate_label_id);
1734
1735        let merge_label_id = self.id_gen.next();
1736        let merge_block = Block::new(merge_label_id);
1737
1738        valid_block.body.push(Instruction::selection_merge(
1739            merge_label_id,
1740            spirv::SelectionControl::NONE,
1741        ));
1742        function.consume(
1743            valid_block,
1744            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1745        );
1746
1747        let vertices_id = self.id_gen.next();
1748        vertex_return_block
1749            .body
1750            .push(Instruction::ray_query_return_vertex_position(
1751                rq_get_vertex_positions_ty_id,
1752                vertices_id,
1753                query_id,
1754                intersection_id,
1755            ));
1756        vertex_return_block
1757            .body
1758            .push(Instruction::store(return_id, vertices_id, None));
1759
1760        function.consume(vertex_return_block, Instruction::branch(merge_label_id));
1761        function.consume(merge_block, Instruction::branch(final_label_id));
1762
1763        let loaded_pos_id = self.id_gen.next();
1764        final_block.body.push(Instruction::load(
1765            rq_get_vertex_positions_ty_id,
1766            loaded_pos_id,
1767            return_id,
1768            None,
1769        ));
1770
1771        function.consume(final_block, Instruction::return_value(loaded_pos_id));
1772
1773        self.ray_query_functions.insert(
1774            LookupRayQueryFunction::GetVertexPositions {
1775                committed: is_committed,
1776            },
1777            func_id,
1778        );
1779
1780        function.to_words(&mut self.logical_layout.function_definitions);
1781
1782        func_id
1783    }
1784}
1785
1786impl BlockContext<'_> {
1787    pub(super) fn write_ray_query_function(
1788        &mut self,
1789        query: Handle<crate::Expression>,
1790        function: &crate::RayQueryFunction,
1791        block: &mut Block,
1792    ) {
1793        let query_id = self.cached[query];
1794        let tracker_ids = *self
1795            .ray_query_tracker_expr
1796            .get(&query)
1797            .expect("not a cached ray query");
1798
1799        match *function {
1800            crate::RayQueryFunction::Initialize {
1801                acceleration_structure,
1802                descriptor,
1803            } => {
1804                let desc_id = self.cached[descriptor];
1805                let acc_struct_id = self.get_handle_id(acceleration_structure);
1806
1807                let func = self.writer.write_ray_query_initialize(self.ir_module);
1808
1809                let func_id = self.gen_id();
1810                block.body.push(Instruction::function_call(
1811                    self.writer.void_type,
1812                    func_id,
1813                    func,
1814                    &[
1815                        query_id,
1816                        acc_struct_id,
1817                        desc_id,
1818                        tracker_ids.initialized_tracker,
1819                        tracker_ids.t_max_tracker,
1820                    ],
1821                ));
1822            }
1823            crate::RayQueryFunction::Proceed { result } => {
1824                let id = self.gen_id();
1825                self.cached[result] = id;
1826
1827                let bool_ty = self.writer.get_bool_type_id();
1828
1829                let func_id = self.writer.write_ray_query_proceed();
1830                block.body.push(Instruction::function_call(
1831                    bool_ty,
1832                    id,
1833                    func_id,
1834                    &[query_id, tracker_ids.initialized_tracker],
1835                ));
1836            }
1837            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1838                let hit_id = self.cached[hit_t];
1839
1840                let func_id = self.writer.write_ray_query_generate_intersection();
1841
1842                let func_call_id = self.gen_id();
1843                block.body.push(Instruction::function_call(
1844                    self.writer.void_type,
1845                    func_call_id,
1846                    func_id,
1847                    &[
1848                        query_id,
1849                        tracker_ids.initialized_tracker,
1850                        hit_id,
1851                        tracker_ids.t_max_tracker,
1852                    ],
1853                ));
1854            }
1855            crate::RayQueryFunction::ConfirmIntersection => {
1856                let func_id = self.writer.write_ray_query_confirm_intersection();
1857
1858                let func_call_id = self.gen_id();
1859                block.body.push(Instruction::function_call(
1860                    self.writer.void_type,
1861                    func_call_id,
1862                    func_id,
1863                    &[query_id, tracker_ids.initialized_tracker],
1864                ));
1865            }
1866            crate::RayQueryFunction::Terminate => {}
1867        }
1868    }
1869
1870    pub(super) fn write_ray_query_return_vertex_position(
1871        &mut self,
1872        query: Handle<crate::Expression>,
1873        block: &mut Block,
1874        is_committed: bool,
1875    ) -> spirv::Word {
1876        let fn_id = self
1877            .writer
1878            .write_ray_query_get_vertex_positions(is_committed, self.ir_module);
1879
1880        let query_id = self.cached[query];
1881        let tracker_id = *self
1882            .ray_query_tracker_expr
1883            .get(&query)
1884            .expect("not a cached ray query");
1885
1886        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1887            *self
1888                .ir_module
1889                .special_types
1890                .ray_vertex_return
1891                .as_ref()
1892                .expect("must be generated when reading in get vertex position"),
1893        );
1894
1895        let func_call_id = self.gen_id();
1896        block.body.push(Instruction::function_call(
1897            rq_get_vertex_positions_ty_id,
1898            func_call_id,
1899            fn_id,
1900            &[query_id, tracker_id.initialized_tracker],
1901        ));
1902        func_call_id
1903    }
1904}