naga/back/spv/
ray.rs

1/*!
2Generating SPIR-V for ray query operations.
3*/
4
5use alloc::{vec, vec::Vec};
6
7use super::{
8    Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType,
9    LookupRayQueryFunction, NumericType, Writer,
10};
11use crate::{arena::Handle, back::RayQueryPoint};
12
13/// helper function to check if a particular flag is set in a u32.
14fn write_ray_flags_contains_flags(
15    writer: &mut Writer,
16    block: &mut Block,
17    id: spirv::Word,
18    flag: u32,
19) -> spirv::Word {
20    let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag));
21    let zero_id = writer.get_constant_scalar(crate::Literal::U32(0));
22    let u32_type_id = writer.get_u32_type_id();
23    let bool_ty = writer.get_bool_type_id();
24
25    let and_id = writer.id_gen.next();
26    block.body.push(Instruction::binary(
27        spirv::Op::BitwiseAnd,
28        u32_type_id,
29        and_id,
30        id,
31        bit_id,
32    ));
33
34    let eq_id = writer.id_gen.next();
35    block.body.push(Instruction::binary(
36        spirv::Op::INotEqual,
37        bool_ty,
38        eq_id,
39        and_id,
40        zero_id,
41    ));
42
43    eq_id
44}
45
46impl Writer {
47    /// writes a logical and of two scalar booleans
48    fn write_logical_and(
49        &mut self,
50        block: &mut Block,
51        one: spirv::Word,
52        two: spirv::Word,
53    ) -> spirv::Word {
54        let id = self.id_gen.next();
55        let bool_id = self.get_bool_type_id();
56        block.body.push(Instruction::binary(
57            spirv::Op::LogicalAnd,
58            bool_id,
59            id,
60            one,
61            two,
62        ));
63        id
64    }
65
66    fn write_reduce_and(&mut self, block: &mut Block, mut bools: Vec<spirv::Word>) -> spirv::Word {
67        // The combined `and`ed together of all of the bools up to this point.
68        let mut current_combined = bools.pop().unwrap();
69        for boolean in bools {
70            current_combined = self.write_logical_and(block, current_combined, boolean)
71        }
72        current_combined
73    }
74
75    // returns the id of the function, the function, and ids for its arguments.
76    fn write_function_signature(
77        &mut self,
78        arg_types: &[spirv::Word],
79        return_ty: spirv::Word,
80    ) -> (spirv::Word, Function, Vec<spirv::Word>) {
81        let func_ty = self.get_function_type(LookupFunctionType {
82            parameter_type_ids: Vec::from(arg_types),
83            return_type_id: return_ty,
84        });
85
86        let mut function = Function::default();
87        let func_id = self.id_gen.next();
88        function.signature = Some(Instruction::function(
89            return_ty,
90            func_id,
91            spirv::FunctionControl::empty(),
92            func_ty,
93        ));
94
95        let mut arg_ids = Vec::with_capacity(arg_types.len());
96
97        for (idx, &arg_ty) in arg_types.iter().enumerate() {
98            let id = self.id_gen.next();
99            let instruction = Instruction::function_parameter(arg_ty, id);
100            function.parameters.push(FunctionArgument {
101                instruction,
102                handle_id: idx as u32,
103            });
104            arg_ids.push(id);
105        }
106        (func_id, function, arg_ids)
107    }
108
109    pub(super) fn write_ray_query_get_intersection_function(
110        &mut self,
111        is_committed: bool,
112        ir_module: &crate::Module,
113    ) -> spirv::Word {
114        if let Some(&word) =
115            self.ray_query_functions
116                .get(&LookupRayQueryFunction::GetIntersection {
117                    committed: is_committed,
118                })
119        {
120            return word;
121        }
122        let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
123        let intersection_type_id = self.get_handle_type_id(ray_intersection);
124        let intersection_pointer_type_id =
125            self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
126
127        let flag_type_id = self.get_u32_type_id();
128        let flag_pointer_type_id =
129            self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
130
131        let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
132            columns: crate::VectorSize::Quad,
133            rows: crate::VectorSize::Tri,
134            scalar: crate::Scalar::F32,
135        });
136        let transform_pointer_type_id =
137            self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
138
139        let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
140            size: crate::VectorSize::Bi,
141            scalar: crate::Scalar::F32,
142        });
143        let barycentrics_pointer_type_id =
144            self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
145
146        let bool_type_id = self.get_bool_type_id();
147        let bool_pointer_type_id =
148            self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
149
150        let scalar_type_id = self.get_f32_type_id();
151        let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
152
153        let argument_type_id = self.get_ray_query_pointer_id();
154
155        let (func_id, mut function, arg_ids) = self.write_function_signature(
156            &[argument_type_id, flag_pointer_type_id],
157            intersection_type_id,
158        );
159
160        let query_id = arg_ids[0];
161        let intersection_tracker_id = arg_ids[1];
162
163        let label_id = self.id_gen.next();
164        let mut block = Block::new(label_id);
165
166        let blank_intersection = self.get_constant_null(intersection_type_id);
167        let blank_intersection_id = self.id_gen.next();
168        // This must be before everything else in the function.
169        block.body.push(Instruction::variable(
170            intersection_pointer_type_id,
171            blank_intersection_id,
172            spirv::StorageClass::Function,
173            Some(blank_intersection),
174        ));
175
176        let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
177            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
178        } else {
179            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
180        } as _));
181
182        let loaded_ray_query_tracker_id = self.id_gen.next();
183        block.body.push(Instruction::load(
184            flag_type_id,
185            loaded_ray_query_tracker_id,
186            intersection_tracker_id,
187            None,
188        ));
189        let proceeded_id = write_ray_flags_contains_flags(
190            self,
191            &mut block,
192            loaded_ray_query_tracker_id,
193            RayQueryPoint::PROCEED.bits(),
194        );
195        let finished_proceed_id = write_ray_flags_contains_flags(
196            self,
197            &mut block,
198            loaded_ray_query_tracker_id,
199            RayQueryPoint::FINISHED_TRAVERSAL.bits(),
200        );
201        let proceed_finished_correct_id = if is_committed {
202            finished_proceed_id
203        } else {
204            let not_finished_id = self.id_gen.next();
205            block.body.push(Instruction::unary(
206                spirv::Op::LogicalNot,
207                bool_type_id,
208                not_finished_id,
209                finished_proceed_id,
210            ));
211            not_finished_id
212        };
213
214        let is_valid_id =
215            self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id);
216
217        let valid_id = self.id_gen.next();
218        let mut valid_block = Block::new(valid_id);
219
220        let final_label_id = self.id_gen.next();
221        let mut final_block = Block::new(final_label_id);
222
223        block.body.push(Instruction::selection_merge(
224            final_label_id,
225            spirv::SelectionControl::NONE,
226        ));
227        function.consume(
228            block,
229            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id),
230        );
231
232        let raw_kind_id = self.id_gen.next();
233        valid_block
234            .body
235            .push(Instruction::ray_query_get_intersection(
236                spirv::Op::RayQueryGetIntersectionTypeKHR,
237                flag_type_id,
238                raw_kind_id,
239                query_id,
240                intersection_id,
241            ));
242        let kind_id = if is_committed {
243            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
244            raw_kind_id
245        } else {
246            // Remap from the candidate kind to IR
247            let condition_id = self.id_gen.next();
248            let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
249                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
250                    as _,
251            ));
252            valid_block.body.push(Instruction::binary(
253                spirv::Op::IEqual,
254                self.get_bool_type_id(),
255                condition_id,
256                raw_kind_id,
257                committed_triangle_kind_id,
258            ));
259            let kind_id = self.id_gen.next();
260            valid_block.body.push(Instruction::select(
261                flag_type_id,
262                kind_id,
263                condition_id,
264                self.get_constant_scalar(crate::Literal::U32(
265                    crate::RayQueryIntersection::Triangle as _,
266                )),
267                self.get_constant_scalar(crate::Literal::U32(
268                    crate::RayQueryIntersection::Aabb as _,
269                )),
270            ));
271            kind_id
272        };
273        let idx_id = self.get_index_constant(0);
274        let access_idx = self.id_gen.next();
275        valid_block.body.push(Instruction::access_chain(
276            flag_pointer_type_id,
277            access_idx,
278            blank_intersection_id,
279            &[idx_id],
280        ));
281        valid_block
282            .body
283            .push(Instruction::store(access_idx, kind_id, None));
284
285        let not_none_comp_id = self.id_gen.next();
286        let none_id =
287            self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
288        valid_block.body.push(Instruction::binary(
289            spirv::Op::INotEqual,
290            self.get_bool_type_id(),
291            not_none_comp_id,
292            kind_id,
293            none_id,
294        ));
295
296        let not_none_label_id = self.id_gen.next();
297        let mut not_none_block = Block::new(not_none_label_id);
298
299        let outer_merge_label_id = self.id_gen.next();
300        let outer_merge_block = Block::new(outer_merge_label_id);
301
302        valid_block.body.push(Instruction::selection_merge(
303            outer_merge_label_id,
304            spirv::SelectionControl::NONE,
305        ));
306        function.consume(
307            valid_block,
308            Instruction::branch_conditional(
309                not_none_comp_id,
310                not_none_label_id,
311                outer_merge_label_id,
312            ),
313        );
314
315        let instance_custom_index_id = self.id_gen.next();
316        not_none_block
317            .body
318            .push(Instruction::ray_query_get_intersection(
319                spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
320                flag_type_id,
321                instance_custom_index_id,
322                query_id,
323                intersection_id,
324            ));
325        let instance_id = self.id_gen.next();
326        not_none_block
327            .body
328            .push(Instruction::ray_query_get_intersection(
329                spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
330                flag_type_id,
331                instance_id,
332                query_id,
333                intersection_id,
334            ));
335        let sbt_record_offset_id = self.id_gen.next();
336        not_none_block
337            .body
338            .push(Instruction::ray_query_get_intersection(
339                spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
340                flag_type_id,
341                sbt_record_offset_id,
342                query_id,
343                intersection_id,
344            ));
345        let geometry_index_id = self.id_gen.next();
346        not_none_block
347            .body
348            .push(Instruction::ray_query_get_intersection(
349                spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
350                flag_type_id,
351                geometry_index_id,
352                query_id,
353                intersection_id,
354            ));
355        let primitive_index_id = self.id_gen.next();
356        not_none_block
357            .body
358            .push(Instruction::ray_query_get_intersection(
359                spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
360                flag_type_id,
361                primitive_index_id,
362                query_id,
363                intersection_id,
364            ));
365
366        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
367        // but it's not a property of an intersection.
368
369        let object_to_world_id = self.id_gen.next();
370        not_none_block
371            .body
372            .push(Instruction::ray_query_get_intersection(
373                spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
374                transform_type_id,
375                object_to_world_id,
376                query_id,
377                intersection_id,
378            ));
379        let world_to_object_id = self.id_gen.next();
380        not_none_block
381            .body
382            .push(Instruction::ray_query_get_intersection(
383                spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
384                transform_type_id,
385                world_to_object_id,
386                query_id,
387                intersection_id,
388            ));
389
390        // instance custom index
391        let idx_id = self.get_index_constant(2);
392        let access_idx = self.id_gen.next();
393        not_none_block.body.push(Instruction::access_chain(
394            flag_pointer_type_id,
395            access_idx,
396            blank_intersection_id,
397            &[idx_id],
398        ));
399        not_none_block.body.push(Instruction::store(
400            access_idx,
401            instance_custom_index_id,
402            None,
403        ));
404
405        // instance
406        let idx_id = self.get_index_constant(3);
407        let access_idx = self.id_gen.next();
408        not_none_block.body.push(Instruction::access_chain(
409            flag_pointer_type_id,
410            access_idx,
411            blank_intersection_id,
412            &[idx_id],
413        ));
414        not_none_block
415            .body
416            .push(Instruction::store(access_idx, instance_id, None));
417
418        let idx_id = self.get_index_constant(4);
419        let access_idx = self.id_gen.next();
420        not_none_block.body.push(Instruction::access_chain(
421            flag_pointer_type_id,
422            access_idx,
423            blank_intersection_id,
424            &[idx_id],
425        ));
426        not_none_block
427            .body
428            .push(Instruction::store(access_idx, sbt_record_offset_id, None));
429
430        let idx_id = self.get_index_constant(5);
431        let access_idx = self.id_gen.next();
432        not_none_block.body.push(Instruction::access_chain(
433            flag_pointer_type_id,
434            access_idx,
435            blank_intersection_id,
436            &[idx_id],
437        ));
438        not_none_block
439            .body
440            .push(Instruction::store(access_idx, geometry_index_id, None));
441
442        let idx_id = self.get_index_constant(6);
443        let access_idx = self.id_gen.next();
444        not_none_block.body.push(Instruction::access_chain(
445            flag_pointer_type_id,
446            access_idx,
447            blank_intersection_id,
448            &[idx_id],
449        ));
450        not_none_block
451            .body
452            .push(Instruction::store(access_idx, primitive_index_id, None));
453
454        let idx_id = self.get_index_constant(9);
455        let access_idx = self.id_gen.next();
456        not_none_block.body.push(Instruction::access_chain(
457            transform_pointer_type_id,
458            access_idx,
459            blank_intersection_id,
460            &[idx_id],
461        ));
462        not_none_block
463            .body
464            .push(Instruction::store(access_idx, object_to_world_id, None));
465
466        let idx_id = self.get_index_constant(10);
467        let access_idx = self.id_gen.next();
468        not_none_block.body.push(Instruction::access_chain(
469            transform_pointer_type_id,
470            access_idx,
471            blank_intersection_id,
472            &[idx_id],
473        ));
474        not_none_block
475            .body
476            .push(Instruction::store(access_idx, world_to_object_id, None));
477
478        let tri_comp_id = self.id_gen.next();
479        let tri_id = self.get_constant_scalar(crate::Literal::U32(
480            crate::RayQueryIntersection::Triangle as _,
481        ));
482        not_none_block.body.push(Instruction::binary(
483            spirv::Op::IEqual,
484            self.get_bool_type_id(),
485            tri_comp_id,
486            kind_id,
487            tri_id,
488        ));
489
490        let tri_label_id = self.id_gen.next();
491        let mut tri_block = Block::new(tri_label_id);
492
493        let merge_label_id = self.id_gen.next();
494        let merge_block = Block::new(merge_label_id);
495        // t
496        {
497            let block = if is_committed {
498                &mut not_none_block
499            } else {
500                &mut tri_block
501            };
502            let t_id = self.id_gen.next();
503            block.body.push(Instruction::ray_query_get_intersection(
504                spirv::Op::RayQueryGetIntersectionTKHR,
505                scalar_type_id,
506                t_id,
507                query_id,
508                intersection_id,
509            ));
510            let idx_id = self.get_index_constant(1);
511            let access_idx = self.id_gen.next();
512            block.body.push(Instruction::access_chain(
513                float_pointer_type_id,
514                access_idx,
515                blank_intersection_id,
516                &[idx_id],
517            ));
518            block.body.push(Instruction::store(access_idx, t_id, None));
519        }
520        not_none_block.body.push(Instruction::selection_merge(
521            merge_label_id,
522            spirv::SelectionControl::NONE,
523        ));
524        function.consume(
525            not_none_block,
526            Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
527        );
528
529        let barycentrics_id = self.id_gen.next();
530        tri_block.body.push(Instruction::ray_query_get_intersection(
531            spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
532            barycentrics_type_id,
533            barycentrics_id,
534            query_id,
535            intersection_id,
536        ));
537
538        let front_face_id = self.id_gen.next();
539        tri_block.body.push(Instruction::ray_query_get_intersection(
540            spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
541            bool_type_id,
542            front_face_id,
543            query_id,
544            intersection_id,
545        ));
546
547        let idx_id = self.get_index_constant(7);
548        let access_idx = self.id_gen.next();
549        tri_block.body.push(Instruction::access_chain(
550            barycentrics_pointer_type_id,
551            access_idx,
552            blank_intersection_id,
553            &[idx_id],
554        ));
555        tri_block
556            .body
557            .push(Instruction::store(access_idx, barycentrics_id, None));
558
559        let idx_id = self.get_index_constant(8);
560        let access_idx = self.id_gen.next();
561        tri_block.body.push(Instruction::access_chain(
562            bool_pointer_type_id,
563            access_idx,
564            blank_intersection_id,
565            &[idx_id],
566        ));
567        tri_block
568            .body
569            .push(Instruction::store(access_idx, front_face_id, None));
570        function.consume(tri_block, Instruction::branch(merge_label_id));
571        function.consume(merge_block, Instruction::branch(outer_merge_label_id));
572        function.consume(outer_merge_block, Instruction::branch(final_label_id));
573
574        let loaded_blank_intersection_id = self.id_gen.next();
575        final_block.body.push(Instruction::load(
576            intersection_type_id,
577            loaded_blank_intersection_id,
578            blank_intersection_id,
579            None,
580        ));
581        function.consume(
582            final_block,
583            Instruction::return_value(loaded_blank_intersection_id),
584        );
585
586        function.to_words(&mut self.logical_layout.function_definitions);
587        self.ray_query_functions.insert(
588            LookupRayQueryFunction::GetIntersection {
589                committed: is_committed,
590            },
591            func_id,
592        );
593        func_id
594    }
595
596    fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word {
597        if let Some(&word) = self
598            .ray_query_functions
599            .get(&LookupRayQueryFunction::Initialize)
600        {
601            return word;
602        }
603
604        let ray_query_type_id = self.get_ray_query_pointer_id();
605        let acceleration_structure_type_id =
606            self.get_localtype_id(super::LocalType::AccelerationStructure);
607        let ray_desc_type_id = self.get_handle_type_id(
608            ir_module
609                .special_types
610                .ray_desc
611                .expect("ray desc should be set if ray queries are being initialized"),
612        );
613
614        let u32_ty = self.get_u32_type_id();
615        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
616
617        let f32_type_id = self.get_f32_type_id();
618        let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
619
620        let bool_type_id = self.get_bool_type_id();
621        let bool_vec3_type_id = self.get_vec3_bool_type_id();
622
623        let (func_id, mut function, arg_ids) = self.write_function_signature(
624            &[
625                ray_query_type_id,
626                acceleration_structure_type_id,
627                ray_desc_type_id,
628                u32_ptr_ty,
629                f32_ptr_ty,
630            ],
631            self.void_type,
632        );
633
634        let query_id = arg_ids[0];
635        let acceleration_structure_id = arg_ids[1];
636        let desc_id = arg_ids[2];
637        let init_tracker_id = arg_ids[3];
638        let t_max_tracker_id = arg_ids[4];
639
640        let label_id = self.id_gen.next();
641        let mut block = Block::new(label_id);
642
643        let flag_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
644
645        //Note: composite extract indices and types must match `generate_ray_desc_type`
646        let ray_flags_id = self.id_gen.next();
647        block.body.push(Instruction::composite_extract(
648            flag_type_id,
649            ray_flags_id,
650            desc_id,
651            &[0],
652        ));
653        let cull_mask_id = self.id_gen.next();
654        block.body.push(Instruction::composite_extract(
655            flag_type_id,
656            cull_mask_id,
657            desc_id,
658            &[1],
659        ));
660
661        let tmin_id = self.id_gen.next();
662        block.body.push(Instruction::composite_extract(
663            f32_type_id,
664            tmin_id,
665            desc_id,
666            &[2],
667        ));
668        let tmax_id = self.id_gen.next();
669        block.body.push(Instruction::composite_extract(
670            f32_type_id,
671            tmax_id,
672            desc_id,
673            &[3],
674        ));
675        block
676            .body
677            .push(Instruction::store(t_max_tracker_id, tmax_id, None));
678
679        let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
680            size: crate::VectorSize::Tri,
681            scalar: crate::Scalar::F32,
682        });
683        let ray_origin_id = self.id_gen.next();
684        block.body.push(Instruction::composite_extract(
685            vector_type_id,
686            ray_origin_id,
687            desc_id,
688            &[4],
689        ));
690        let ray_dir_id = self.id_gen.next();
691        block.body.push(Instruction::composite_extract(
692            vector_type_id,
693            ray_dir_id,
694            desc_id,
695            &[5],
696        ));
697
698        let valid_id = self.ray_query_initialization_tracking.then(||{
699            let tmin_le_tmax_id = self.id_gen.next();
700            // Check both that tmin is less than or equal to tmax (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06350)
701            // and implicitly that neither tmin or tmax are NaN (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06351)
702            // because this checks if tmin and tmax are ordered too (i.e: not NaN).
703            block.body.push(Instruction::binary(
704                spirv::Op::FOrdLessThanEqual,
705                bool_type_id,
706                tmin_le_tmax_id,
707                tmin_id,
708                tmax_id,
709            ));
710
711            // Check that tmin is greater than or equal to 0 (and
712            // therefore also tmax is too because it is greater than
713            // or equal to tmin) (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06349).
714            let tmin_ge_zero_id = self.id_gen.next();
715            let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0));
716            block.body.push(Instruction::binary(
717                spirv::Op::FOrdGreaterThanEqual,
718                bool_type_id,
719                tmin_ge_zero_id,
720                tmin_id,
721                zero_id,
722            ));
723
724            // Check that ray origin is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
725            let ray_origin_infinite_id = self.id_gen.next();
726            block.body.push(Instruction::unary(
727                spirv::Op::IsInf,
728                bool_vec3_type_id,
729                ray_origin_infinite_id,
730                ray_origin_id,
731            ));
732            let any_ray_origin_infinite_id = self.id_gen.next();
733            block.body.push(Instruction::unary(
734                spirv::Op::Any,
735                bool_type_id,
736                any_ray_origin_infinite_id,
737                ray_origin_infinite_id,
738            ));
739
740            let ray_origin_nan_id = self.id_gen.next();
741            block.body.push(Instruction::unary(
742                spirv::Op::IsNan,
743                bool_vec3_type_id,
744                ray_origin_nan_id,
745                ray_origin_id,
746            ));
747            let any_ray_origin_nan_id = self.id_gen.next();
748            block.body.push(Instruction::unary(
749                spirv::Op::Any,
750                bool_type_id,
751                any_ray_origin_nan_id,
752                ray_origin_nan_id,
753            ));
754
755            let ray_origin_not_finite_id = self.id_gen.next();
756            block.body.push(Instruction::binary(
757                spirv::Op::LogicalOr,
758                bool_type_id,
759                ray_origin_not_finite_id,
760                any_ray_origin_nan_id,
761                any_ray_origin_infinite_id,
762            ));
763
764            let all_ray_origin_finite_id = self.id_gen.next();
765            block.body.push(Instruction::unary(
766                spirv::Op::LogicalNot,
767                bool_type_id,
768                all_ray_origin_finite_id,
769                ray_origin_not_finite_id,
770            ));
771
772            // Check that ray direction is finite (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06348)
773            let ray_dir_infinite_id = self.id_gen.next();
774            block.body.push(Instruction::unary(
775                spirv::Op::IsInf,
776                bool_vec3_type_id,
777                ray_dir_infinite_id,
778                ray_dir_id,
779            ));
780            let any_ray_dir_infinite_id = self.id_gen.next();
781            block.body.push(Instruction::unary(
782                spirv::Op::Any,
783                bool_type_id,
784                any_ray_dir_infinite_id,
785                ray_dir_infinite_id,
786            ));
787
788            let ray_dir_nan_id = self.id_gen.next();
789            block.body.push(Instruction::unary(
790                spirv::Op::IsNan,
791                bool_vec3_type_id,
792                ray_dir_nan_id,
793                ray_dir_id,
794            ));
795            let any_ray_dir_nan_id = self.id_gen.next();
796            block.body.push(Instruction::unary(
797                spirv::Op::Any,
798                bool_type_id,
799                any_ray_dir_nan_id,
800                ray_dir_nan_id,
801            ));
802
803            let ray_dir_not_finite_id = self.id_gen.next();
804            block.body.push(Instruction::binary(
805                spirv::Op::LogicalOr,
806                bool_type_id,
807                ray_dir_not_finite_id,
808                any_ray_dir_nan_id,
809                any_ray_dir_infinite_id,
810            ));
811
812            let all_ray_dir_finite_id = self.id_gen.next();
813            block.body.push(Instruction::unary(
814                spirv::Op::LogicalNot,
815                bool_type_id,
816                all_ray_dir_finite_id,
817                ray_dir_not_finite_id,
818            ));
819
            /// Writes SPIR-V that evaluates to `true` when fewer than two of `bools`
            /// are `true` (i.e. at most one of them is set).
            ///
            /// Every unordered pair of booleans is combined with `OpLogicalAnd`; those
            /// pair results are combined with `OpLogicalOr`, which yields "two or more
            /// are true"; finally that is negated with `OpLogicalNot`. This emits
            /// `n * (n - 1) / 2` ANDs, which is acceptable for the short flag lists
            /// passed in by the caller.
            ///
            /// Returns the id of the resulting scalar boolean.
            ///
            /// # Panics
            ///
            /// Panics if `bools` has fewer than two elements.
            fn write_less_than_2_true(
                writer: &mut Writer,
                block: &mut Block,
                mut bools: Vec<spirv::Word>,
            ) -> spirv::Word {
                assert!(bools.len() > 1, "Must have multiple booleans!");
                let bool_ty = writer.get_bool_type_id();
                let mut each_two_true = Vec::new();
                // Pop each boolean and AND it with every boolean still in the list,
                // so each unordered pair is visited exactly once.
                while let Some(last_bool) = bools.pop() {
                    for &bool in &bools {
                        let both_true_id = writer.write_logical_and(
                            block,
                            last_bool,
                            bool,
                        );
                        each_two_true.push(both_true_id);
                    }
                }
                // OR all pair checks together: true iff at least two inputs are true.
                let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true`");
                for two_true in each_two_true {
                    let new_all_or_id = writer.id_gen.next();
                    block.body.push(Instruction::binary(
                        spirv::Op::LogicalOr,
                        bool_ty,
                        new_all_or_id,
                        all_or_id,
                        two_true,
                    ));
                    all_or_id = new_all_or_id;
                }

                // Negate "two or more are true" to get "fewer than two are true".
                let less_than_two_id = writer.id_gen.next();
                block.body.push(Instruction::unary(
                    spirv::Op::LogicalNot,
                    bool_ty,
                    less_than_two_id,
                    all_or_id,
                ));
                less_than_two_id
            }
864
865            // Check that at most one of skip triangles and skip AABBs is
866            // present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06889)
867            let contains_skip_triangles = write_ray_flags_contains_flags(
868                self,
869                &mut block,
870                ray_flags_id,
871                crate::RayFlag::SKIP_TRIANGLES.bits(),
872            );
873            let contains_skip_aabbs = write_ray_flags_contains_flags(
874                self,
875                &mut block,
876                ray_flags_id,
877                crate::RayFlag::SKIP_AABBS.bits(),
878            );
879
880            let not_contain_skip_triangles_aabbs = write_less_than_2_true(
881                self,
882                &mut block,
883                vec![contains_skip_triangles, contains_skip_aabbs],
884            );
885
886            // Check that at most one of skip triangles (taken from above check),
887            // cull back facing, and cull front face is present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06890)
888            let contains_cull_back = write_ray_flags_contains_flags(
889                self,
890                &mut block,
891                ray_flags_id,
892                crate::RayFlag::CULL_BACK_FACING.bits(),
893            );
894            let contains_cull_front = write_ray_flags_contains_flags(
895                self,
896                &mut block,
897                ray_flags_id,
898                crate::RayFlag::CULL_FRONT_FACING.bits(),
899            );
900
901            let not_contain_skip_triangles_cull = write_less_than_2_true(
902                self,
903                &mut block,
904                vec![
905                    contains_skip_triangles,
906                    contains_cull_back,
907                    contains_cull_front,
908                ],
909            );
910
911            // Check that at most one of force opaque, force not opaque, cull opaque,
912            // and cull not opaque are present (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryInitializeKHR-06891)
913            let contains_opaque = write_ray_flags_contains_flags(
914                self,
915                &mut block,
916                ray_flags_id,
917                crate::RayFlag::FORCE_OPAQUE.bits(),
918            );
919            let contains_no_opaque = write_ray_flags_contains_flags(
920                self,
921                &mut block,
922                ray_flags_id,
923                crate::RayFlag::FORCE_NO_OPAQUE.bits(),
924            );
925            let contains_cull_opaque = write_ray_flags_contains_flags(
926                self,
927                &mut block,
928                ray_flags_id,
929                crate::RayFlag::CULL_OPAQUE.bits(),
930            );
931            let contains_cull_no_opaque = write_ray_flags_contains_flags(
932                self,
933                &mut block,
934                ray_flags_id,
935                crate::RayFlag::CULL_NO_OPAQUE.bits(),
936            );
937
938            let not_contain_multiple_opaque = write_less_than_2_true(
939                self,
940                &mut block,
941                vec![
942                    contains_opaque,
943                    contains_no_opaque,
944                    contains_cull_opaque,
945                    contains_cull_no_opaque,
946                ],
947            );
948
949            // Combine all checks into a single flag saying whether the call is valid or not.
950            self.write_reduce_and(
951                &mut block,
952                vec![
953                    tmin_le_tmax_id,
954                    tmin_ge_zero_id,
955                    all_ray_origin_finite_id,
956                    all_ray_dir_finite_id,
957                    not_contain_skip_triangles_aabbs,
958                    not_contain_skip_triangles_cull,
959                    not_contain_multiple_opaque,
960                ],
961            )
962        });
963
964        let merge_label_id = self.id_gen.next();
965        let merge_block = Block::new(merge_label_id);
966
967        // NOTE: this block will be unreachable if initialization tracking is disabled.
968        let invalid_label_id = self.id_gen.next();
969        let mut invalid_block = Block::new(invalid_label_id);
970
971        let valid_label_id = self.id_gen.next();
972        let mut valid_block = Block::new(valid_label_id);
973
974        match valid_id {
975            Some(all_valid_id) => {
976                block.body.push(Instruction::selection_merge(
977                    merge_label_id,
978                    spirv::SelectionControl::NONE,
979                ));
980                function.consume(
981                    block,
982                    Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
983                );
984            }
985            None => {
986                function.consume(block, Instruction::branch(valid_label_id));
987            }
988        }
989
990        valid_block.body.push(Instruction::ray_query_initialize(
991            query_id,
992            acceleration_structure_id,
993            ray_flags_id,
994            cull_mask_id,
995            ray_origin_id,
996            tmin_id,
997            ray_dir_id,
998            tmax_id,
999        ));
1000
1001        let const_initialized =
1002            self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::INITIALIZED.bits()));
1003        valid_block
1004            .body
1005            .push(Instruction::store(init_tracker_id, const_initialized, None));
1006
1007        function.consume(valid_block, Instruction::branch(merge_label_id));
1008
1009        if self
1010            .flags
1011            .contains(super::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
1012        {
1013            self.write_debug_printf(
1014                &mut invalid_block,
1015                "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
1016                &[
1017                    ray_flags_id,
1018                    tmin_id,
1019                    tmax_id,
1020                    ray_origin_id,
1021                    ray_dir_id,
1022                ],
1023            );
1024        }
1025
1026        function.consume(invalid_block, Instruction::branch(merge_label_id));
1027
1028        function.consume(merge_block, Instruction::return_void());
1029
1030        function.to_words(&mut self.logical_layout.function_definitions);
1031
1032        self.ray_query_functions
1033            .insert(LookupRayQueryFunction::Initialize, func_id);
1034        func_id
1035    }
1036
    /// Writes (or returns the cached id of) a wrapper function around
    /// `OpRayQueryProceedKHR`.
    ///
    /// Generated signature: `(query: *RayQuery, init_tracker: *u32) -> bool`.
    ///
    /// When initialization tracking is enabled, the proceed is performed only if
    /// the `INITIALIZED` bit is set in `*init_tracker`; otherwise the function
    /// returns `false` without touching the query. After a real proceed the
    /// tracker gains `PROCEED`, plus `FINISHED_TRAVERSAL` when the proceed
    /// returned `false`.
    ///
    /// The function is emitted at most once per writer; later calls return the
    /// id stored in `ray_query_functions`.
    fn write_ray_query_proceed(&mut self) -> spirv::Word {
        if let Some(&word) = self
            .ray_query_functions
            .get(&LookupRayQueryFunction::Proceed)
        {
            return word;
        }

        let ray_query_type_id = self.get_ray_query_pointer_id();

        let u32_ty = self.get_u32_type_id();
        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

        let bool_type_id = self.get_bool_type_id();
        let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);

        let (func_id, mut function, arg_ids) =
            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id);

        let query_id = arg_ids[0];
        let init_tracker_id = arg_ids[1];

        let block_id = self.id_gen.next();
        let mut block = Block::new(block_id);

        // Result slot, initialized to `false` so the path that skips the
        // proceed (uninitialized query) returns `false`.
        // TODO: perhaps this could be replaced with an OpPhi?
        let proceeded_id = self.id_gen.next();
        let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
        block.body.push(Instruction::variable(
            bool_ptr_ty,
            proceeded_id,
            spirv::StorageClass::Function,
            Some(const_false),
        ));

        // Loaded in the entry block so it is available both for the validity
        // check below and for the flag update in the valid block.
        let initialized_tracker_id = self.id_gen.next();
        block.body.push(Instruction::load(
            u32_ty,
            initialized_tracker_id,
            init_tracker_id,
            None,
        ));

        let merge_id = self.id_gen.next();
        let mut merge_block = Block::new(merge_id);

        let valid_block_id = self.id_gen.next();
        let mut valid_block = Block::new(valid_block_id);

        // With tracking enabled, jump straight to the merge block (which
        // returns the default `false`) unless the query was initialized.
        let instruction = if self.ray_query_initialization_tracking {
            let is_initialized = write_ray_flags_contains_flags(
                self,
                &mut block,
                initialized_tracker_id,
                RayQueryPoint::INITIALIZED.bits(),
            );

            block.body.push(Instruction::selection_merge(
                merge_id,
                spirv::SelectionControl::NONE,
            ));

            Instruction::branch_conditional(is_initialized, valid_block_id, merge_id)
        } else {
            Instruction::branch(valid_block_id)
        };

        function.consume(block, instruction);

        let has_proceeded = self.id_gen.next();
        valid_block.body.push(Instruction::ray_query_proceed(
            bool_type_id,
            has_proceeded,
            query_id,
        ));

        valid_block
            .body
            .push(Instruction::store(proceeded_id, has_proceeded, None));

        let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
            (RayQueryPoint::PROCEED | RayQueryPoint::FINISHED_TRAVERSAL).bits(),
        ));
        let add_flag_continuing =
            self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::PROCEED.bits()));

        // A `false` result from the proceed means traversal is over, so also
        // record FINISHED_TRAVERSAL in that case; always record PROCEED.
        let add_flags_id = self.id_gen.next();
        valid_block.body.push(Instruction::select(
            u32_ty,
            add_flags_id,
            has_proceeded,
            add_flag_continuing,
            add_flag_finished,
        ));
        // OR the new bits into the previously loaded tracker value and write it back.
        let final_flags = self.id_gen.next();
        valid_block.body.push(Instruction::binary(
            spirv::Op::BitwiseOr,
            u32_ty,
            final_flags,
            initialized_tracker_id,
            add_flags_id,
        ));
        valid_block
            .body
            .push(Instruction::store(init_tracker_id, final_flags, None));

        function.consume(valid_block, Instruction::branch(merge_id));

        // Both paths meet here; return whatever ended up in the result slot.
        let loaded_proceeded_id = self.id_gen.next();
        merge_block.body.push(Instruction::load(
            bool_type_id,
            loaded_proceeded_id,
            proceeded_id,
            None,
        ));

        function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));

        function.to_words(&mut self.logical_layout.function_definitions);

        self.ray_query_functions
            .insert(LookupRayQueryFunction::Proceed, func_id);
        func_id
    }
1161
1162    fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
1163        if let Some(&word) = self
1164            .ray_query_functions
1165            .get(&LookupRayQueryFunction::GenerateIntersection)
1166        {
1167            return word;
1168        }
1169
1170        let ray_query_type_id = self.get_ray_query_pointer_id();
1171
1172        let u32_ty = self.get_u32_type_id();
1173        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1174
1175        let f32_type_id = self.get_f32_type_id();
1176        let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1177
1178        let bool_type_id = self.get_bool_type_id();
1179
1180        let (func_id, mut function, arg_ids) = self.write_function_signature(
1181            &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id],
1182            self.void_type,
1183        );
1184
1185        let query_id = arg_ids[0];
1186        let init_tracker_id = arg_ids[1];
1187        let depth_id = arg_ids[2];
1188        let t_max_tracker_id = arg_ids[3];
1189
1190        let block_id = self.id_gen.next();
1191        let mut block = Block::new(block_id);
1192
1193        let current_t = self.id_gen.next();
1194        block.body.push(Instruction::variable(
1195            f32_ptr_type_id,
1196            current_t,
1197            spirv::StorageClass::Function,
1198            None,
1199        ));
1200
1201        let current_t = self.id_gen.next();
1202        block.body.push(Instruction::variable(
1203            f32_ptr_type_id,
1204            current_t,
1205            spirv::StorageClass::Function,
1206            None,
1207        ));
1208
1209        let valid_id = self.id_gen.next();
1210        let mut valid_block = Block::new(valid_id);
1211
1212        let final_label_id = self.id_gen.next();
1213        let final_block = Block::new(final_label_id);
1214
1215        let instruction = if self.ray_query_initialization_tracking {
1216            let initialized_tracker_id = self.id_gen.next();
1217            block.body.push(Instruction::load(
1218                u32_ty,
1219                initialized_tracker_id,
1220                init_tracker_id,
1221                None,
1222            ));
1223
1224            let proceeded_id = write_ray_flags_contains_flags(
1225                self,
1226                &mut block,
1227                initialized_tracker_id,
1228                RayQueryPoint::PROCEED.bits(),
1229            );
1230            let finished_proceed_id = write_ray_flags_contains_flags(
1231                self,
1232                &mut block,
1233                initialized_tracker_id,
1234                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1235            );
1236
1237            // Can't find anything to suggest double calling this function is invalid.
1238
1239            let not_finished_id = self.id_gen.next();
1240            block.body.push(Instruction::unary(
1241                spirv::Op::LogicalNot,
1242                bool_type_id,
1243                not_finished_id,
1244                finished_proceed_id,
1245            ));
1246
1247            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1248
1249            block.body.push(Instruction::selection_merge(
1250                final_label_id,
1251                spirv::SelectionControl::NONE,
1252            ));
1253
1254            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1255        } else {
1256            Instruction::branch(valid_id)
1257        };
1258
1259        function.consume(block, instruction);
1260
1261        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1262            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1263        ));
1264        let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32(
1265            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
1266        ));
1267        let raw_kind_id = self.id_gen.next();
1268        valid_block
1269            .body
1270            .push(Instruction::ray_query_get_intersection(
1271                spirv::Op::RayQueryGetIntersectionTypeKHR,
1272                u32_ty,
1273                raw_kind_id,
1274                query_id,
1275                intersection_id,
1276            ));
1277
1278        let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
1279            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
1280        ));
1281        let intersection_aabb_id = self.id_gen.next();
1282        valid_block.body.push(Instruction::binary(
1283            spirv::Op::IEqual,
1284            bool_type_id,
1285            intersection_aabb_id,
1286            raw_kind_id,
1287            candidate_aabb_id,
1288        ));
1289
1290        // Check that the provided t value is between t min and the current committed
1291        // t value, (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryGenerateIntersectionKHR-06353)
1292
1293        // Get the tmin
1294        let t_min_id = self.id_gen.next();
1295        valid_block.body.push(Instruction::ray_query_get_t_min(
1296            f32_type_id,
1297            t_min_id,
1298            query_id,
1299        ));
1300
1301        // Get the current committed t, or tmax if no hit.
1302        // Basically emulate HLSL's (easier) version
1303        // Pseudo-code:
1304        // ````wgsl
1305        // // start of function
1306        // var current_t:f32;
1307        // ...
1308        // let committed_type_id = RayQueryGetIntersectionTypeKHR<Committed>(query_id);
1309        // if committed_type_id == Committed_None {
1310        //     current_t = load(t_max_tracker);
1311        // } else {
1312        //     current_t = RayQueryGetIntersectionTKHR<Committed>(query_id);
1313        // }
1314        // ...
1315        // ````
1316
1317        let committed_type_id = self.id_gen.next();
1318        valid_block
1319            .body
1320            .push(Instruction::ray_query_get_intersection(
1321                spirv::Op::RayQueryGetIntersectionTypeKHR,
1322                u32_ty,
1323                committed_type_id,
1324                query_id,
1325                committed_intersection_id,
1326            ));
1327
1328        let no_committed = self.id_gen.next();
1329        valid_block.body.push(Instruction::binary(
1330            spirv::Op::IEqual,
1331            bool_type_id,
1332            no_committed,
1333            committed_type_id,
1334            self.get_constant_scalar(crate::Literal::U32(
1335                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _,
1336            )),
1337        ));
1338
1339        let next_valid_block_id = self.id_gen.next();
1340        let no_committed_block_id = self.id_gen.next();
1341        let mut no_committed_block = Block::new(no_committed_block_id);
1342        let committed_block_id = self.id_gen.next();
1343        let mut committed_block = Block::new(committed_block_id);
1344        valid_block.body.push(Instruction::selection_merge(
1345            next_valid_block_id,
1346            spirv::SelectionControl::NONE,
1347        ));
1348        function.consume(
1349            valid_block,
1350            Instruction::branch_conditional(
1351                no_committed,
1352                no_committed_block_id,
1353                committed_block_id,
1354            ),
1355        );
1356
1357        // Assign t_max to current_t
1358        let t_max_id = self.id_gen.next();
1359        no_committed_block.body.push(Instruction::load(
1360            f32_type_id,
1361            t_max_id,
1362            t_max_tracker_id,
1363            None,
1364        ));
1365        no_committed_block
1366            .body
1367            .push(Instruction::store(current_t, t_max_id, None));
1368        function.consume(no_committed_block, Instruction::branch(next_valid_block_id));
1369
1370        // Assign t_current to current_t
1371        let latest_t_id = self.id_gen.next();
1372        committed_block
1373            .body
1374            .push(Instruction::ray_query_get_intersection(
1375                spirv::Op::RayQueryGetIntersectionTKHR,
1376                f32_type_id,
1377                latest_t_id,
1378                query_id,
1379                intersection_id,
1380            ));
1381        committed_block
1382            .body
1383            .push(Instruction::store(current_t, latest_t_id, None));
1384        function.consume(committed_block, Instruction::branch(next_valid_block_id));
1385
1386        let mut valid_block = Block::new(next_valid_block_id);
1387
1388        let t_ge_t_min = self.id_gen.next();
1389        valid_block.body.push(Instruction::binary(
1390            spirv::Op::FOrdGreaterThanEqual,
1391            bool_type_id,
1392            t_ge_t_min,
1393            depth_id,
1394            t_min_id,
1395        ));
1396        let t_current = self.id_gen.next();
1397        valid_block
1398            .body
1399            .push(Instruction::load(f32_type_id, t_current, current_t, None));
1400        let t_le_t_current = self.id_gen.next();
1401        valid_block.body.push(Instruction::binary(
1402            spirv::Op::FOrdLessThanEqual,
1403            bool_type_id,
1404            t_le_t_current,
1405            depth_id,
1406            t_current,
1407        ));
1408
1409        let t_in_range = self.id_gen.next();
1410        valid_block.body.push(Instruction::binary(
1411            spirv::Op::LogicalAnd,
1412            bool_type_id,
1413            t_in_range,
1414            t_ge_t_min,
1415            t_le_t_current,
1416        ));
1417
1418        let call_valid_id = self.id_gen.next();
1419        valid_block.body.push(Instruction::binary(
1420            spirv::Op::LogicalAnd,
1421            bool_type_id,
1422            call_valid_id,
1423            t_in_range,
1424            intersection_aabb_id,
1425        ));
1426
1427        let generate_label_id = self.id_gen.next();
1428        let mut generate_block = Block::new(generate_label_id);
1429
1430        let merge_label_id = self.id_gen.next();
1431        let merge_block = Block::new(merge_label_id);
1432
1433        valid_block.body.push(Instruction::selection_merge(
1434            merge_label_id,
1435            spirv::SelectionControl::NONE,
1436        ));
1437        function.consume(
1438            valid_block,
1439            Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id),
1440        );
1441
1442        generate_block
1443            .body
1444            .push(Instruction::ray_query_generate_intersection(
1445                query_id, depth_id,
1446            ));
1447
1448        function.consume(generate_block, Instruction::branch(merge_label_id));
1449        function.consume(merge_block, Instruction::branch(final_label_id));
1450
1451        function.consume(final_block, Instruction::return_void());
1452
1453        function.to_words(&mut self.logical_layout.function_definitions);
1454
1455        self.ray_query_functions
1456            .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
1457        func_id
1458    }
1459
1460    fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
1461        if let Some(&word) = self
1462            .ray_query_functions
1463            .get(&LookupRayQueryFunction::ConfirmIntersection)
1464        {
1465            return word;
1466        }
1467
1468        let ray_query_type_id = self.get_ray_query_pointer_id();
1469
1470        let u32_ty = self.get_u32_type_id();
1471        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1472
1473        let bool_type_id = self.get_bool_type_id();
1474
1475        let (func_id, mut function, arg_ids) =
1476            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1477
1478        let query_id = arg_ids[0];
1479        let init_tracker_id = arg_ids[1];
1480
1481        let block_id = self.id_gen.next();
1482        let mut block = Block::new(block_id);
1483
1484        let valid_id = self.id_gen.next();
1485        let mut valid_block = Block::new(valid_id);
1486
1487        let final_label_id = self.id_gen.next();
1488        let final_block = Block::new(final_label_id);
1489
1490        let instruction = if self.ray_query_initialization_tracking {
1491            let initialized_tracker_id = self.id_gen.next();
1492            block.body.push(Instruction::load(
1493                u32_ty,
1494                initialized_tracker_id,
1495                init_tracker_id,
1496                None,
1497            ));
1498
1499            let proceeded_id = write_ray_flags_contains_flags(
1500                self,
1501                &mut block,
1502                initialized_tracker_id,
1503                RayQueryPoint::PROCEED.bits(),
1504            );
1505            let finished_proceed_id = write_ray_flags_contains_flags(
1506                self,
1507                &mut block,
1508                initialized_tracker_id,
1509                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1510            );
1511            // Although it seems strange to call this twice, I (Vecvec) can't find anything to suggest double calling this function is invalid.
1512            let not_finished_id = self.id_gen.next();
1513            block.body.push(Instruction::unary(
1514                spirv::Op::LogicalNot,
1515                bool_type_id,
1516                not_finished_id,
1517                finished_proceed_id,
1518            ));
1519
1520            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1521
1522            block.body.push(Instruction::selection_merge(
1523                final_label_id,
1524                spirv::SelectionControl::NONE,
1525            ));
1526
1527            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1528        } else {
1529            Instruction::branch(valid_id)
1530        };
1531
1532        function.consume(block, instruction);
1533
1534        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1535            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1536        ));
1537        let raw_kind_id = self.id_gen.next();
1538        valid_block
1539            .body
1540            .push(Instruction::ray_query_get_intersection(
1541                spirv::Op::RayQueryGetIntersectionTypeKHR,
1542                u32_ty,
1543                raw_kind_id,
1544                query_id,
1545                intersection_id,
1546            ));
1547
1548        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
1549            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
1550        ));
1551        let intersection_tri_id = self.id_gen.next();
1552        valid_block.body.push(Instruction::binary(
1553            spirv::Op::IEqual,
1554            bool_type_id,
1555            intersection_tri_id,
1556            raw_kind_id,
1557            candidate_tri_id,
1558        ));
1559
1560        let generate_label_id = self.id_gen.next();
1561        let mut generate_block = Block::new(generate_label_id);
1562
1563        let merge_label_id = self.id_gen.next();
1564        let merge_block = Block::new(merge_label_id);
1565
1566        valid_block.body.push(Instruction::selection_merge(
1567            merge_label_id,
1568            spirv::SelectionControl::NONE,
1569        ));
1570        function.consume(
1571            valid_block,
1572            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1573        );
1574
1575        generate_block
1576            .body
1577            .push(Instruction::ray_query_confirm_intersection(query_id));
1578
1579        function.consume(generate_block, Instruction::branch(merge_label_id));
1580        function.consume(merge_block, Instruction::branch(final_label_id));
1581
1582        function.consume(final_block, Instruction::return_void());
1583
1584        self.ray_query_functions
1585            .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);
1586
1587        function.to_words(&mut self.logical_layout.function_definitions);
1588
1589        func_id
1590    }
1591
1592    fn write_ray_query_get_vertex_positions(
1593        &mut self,
1594        is_committed: bool,
1595        ir_module: &crate::Module,
1596    ) -> spirv::Word {
1597        if let Some(&word) =
1598            self.ray_query_functions
1599                .get(&LookupRayQueryFunction::GetVertexPositions {
1600                    committed: is_committed,
1601                })
1602        {
1603            return word;
1604        }
1605
1606        let (committed_ty, committed_tri_ty) = if is_committed {
1607            (
1608                spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
1609                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
1610                    as u32,
1611            )
1612        } else {
1613            (
1614                spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
1615                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
1616                    as u32,
1617            )
1618        };
1619
1620        let ray_query_type_id = self.get_ray_query_pointer_id();
1621
1622        let u32_ty = self.get_u32_type_id();
1623        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1624
1625        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1626            *ir_module
1627                .special_types
1628                .ray_vertex_return
1629                .as_ref()
1630                .expect("must be generated when reading in get vertex position"),
1631        );
1632        let ptr_return_ty =
1633            self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);
1634
1635        let bool_type_id = self.get_bool_type_id();
1636
1637        let (func_id, mut function, arg_ids) = self.write_function_signature(
1638            &[ray_query_type_id, u32_ptr_ty],
1639            rq_get_vertex_positions_ty_id,
1640        );
1641
1642        let query_id = arg_ids[0];
1643        let init_tracker_id = arg_ids[1];
1644
1645        let block_id = self.id_gen.next();
1646        let mut block = Block::new(block_id);
1647
1648        let return_id = self.id_gen.next();
1649        block.body.push(Instruction::variable(
1650            ptr_return_ty,
1651            return_id,
1652            spirv::StorageClass::Function,
1653            Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
1654        ));
1655
1656        let valid_id = self.id_gen.next();
1657        let mut valid_block = Block::new(valid_id);
1658
1659        let final_label_id = self.id_gen.next();
1660        let mut final_block = Block::new(final_label_id);
1661
1662        let instruction = if self.ray_query_initialization_tracking {
1663            let initialized_tracker_id = self.id_gen.next();
1664            block.body.push(Instruction::load(
1665                u32_ty,
1666                initialized_tracker_id,
1667                init_tracker_id,
1668                None,
1669            ));
1670
1671            let proceeded_id = write_ray_flags_contains_flags(
1672                self,
1673                &mut block,
1674                initialized_tracker_id,
1675                RayQueryPoint::PROCEED.bits(),
1676            );
1677            let finished_proceed_id = write_ray_flags_contains_flags(
1678                self,
1679                &mut block,
1680                initialized_tracker_id,
1681                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1682            );
1683
1684            let correct_finish_id = if is_committed {
1685                finished_proceed_id
1686            } else {
1687                let not_finished_id = self.id_gen.next();
1688                block.body.push(Instruction::unary(
1689                    spirv::Op::LogicalNot,
1690                    bool_type_id,
1691                    not_finished_id,
1692                    finished_proceed_id,
1693                ));
1694                not_finished_id
1695            };
1696
1697            let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id);
1698            block.body.push(Instruction::selection_merge(
1699                final_label_id,
1700                spirv::SelectionControl::NONE,
1701            ));
1702            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1703        } else {
1704            Instruction::branch(valid_id)
1705        };
1706
1707        function.consume(block, instruction);
1708
1709        let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
1710        let raw_kind_id = self.id_gen.next();
1711        valid_block
1712            .body
1713            .push(Instruction::ray_query_get_intersection(
1714                spirv::Op::RayQueryGetIntersectionTypeKHR,
1715                u32_ty,
1716                raw_kind_id,
1717                query_id,
1718                intersection_id,
1719            ));
1720
1721        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
1722        let intersection_tri_id = self.id_gen.next();
1723        valid_block.body.push(Instruction::binary(
1724            spirv::Op::IEqual,
1725            bool_type_id,
1726            intersection_tri_id,
1727            raw_kind_id,
1728            candidate_tri_id,
1729        ));
1730
1731        let generate_label_id = self.id_gen.next();
1732        let mut vertex_return_block = Block::new(generate_label_id);
1733
1734        let merge_label_id = self.id_gen.next();
1735        let merge_block = Block::new(merge_label_id);
1736
1737        valid_block.body.push(Instruction::selection_merge(
1738            merge_label_id,
1739            spirv::SelectionControl::NONE,
1740        ));
1741        function.consume(
1742            valid_block,
1743            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1744        );
1745
1746        let vertices_id = self.id_gen.next();
1747        vertex_return_block
1748            .body
1749            .push(Instruction::ray_query_return_vertex_position(
1750                rq_get_vertex_positions_ty_id,
1751                vertices_id,
1752                query_id,
1753                intersection_id,
1754            ));
1755        vertex_return_block
1756            .body
1757            .push(Instruction::store(return_id, vertices_id, None));
1758
1759        function.consume(vertex_return_block, Instruction::branch(merge_label_id));
1760        function.consume(merge_block, Instruction::branch(final_label_id));
1761
1762        let loaded_pos_id = self.id_gen.next();
1763        final_block.body.push(Instruction::load(
1764            rq_get_vertex_positions_ty_id,
1765            loaded_pos_id,
1766            return_id,
1767            None,
1768        ));
1769
1770        function.consume(final_block, Instruction::return_value(loaded_pos_id));
1771
1772        self.ray_query_functions.insert(
1773            LookupRayQueryFunction::GetVertexPositions {
1774                committed: is_committed,
1775            },
1776            func_id,
1777        );
1778
1779        function.to_words(&mut self.logical_layout.function_definitions);
1780
1781        func_id
1782    }
1783
1784    fn write_ray_query_terminate(&mut self) -> spirv::Word {
1785        if let Some(&word) = self
1786            .ray_query_functions
1787            .get(&LookupRayQueryFunction::Terminate)
1788        {
1789            return word;
1790        }
1791
1792        let ray_query_type_id = self.get_ray_query_pointer_id();
1793
1794        let u32_ty = self.get_u32_type_id();
1795        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1796
1797        let bool_type_id = self.get_bool_type_id();
1798
1799        let (func_id, mut function, arg_ids) =
1800            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1801
1802        let query_id = arg_ids[0];
1803        let init_tracker_id = arg_ids[1];
1804
1805        let block_id = self.id_gen.next();
1806        let mut block = Block::new(block_id);
1807
1808        let initialized_tracker_id = self.id_gen.next();
1809        block.body.push(Instruction::load(
1810            u32_ty,
1811            initialized_tracker_id,
1812            init_tracker_id,
1813            None,
1814        ));
1815
1816        let merge_id = self.id_gen.next();
1817        let merge_block = Block::new(merge_id);
1818
1819        let valid_block_id = self.id_gen.next();
1820        let mut valid_block = Block::new(valid_block_id);
1821
1822        let instruction = if self.ray_query_initialization_tracking {
1823            let has_proceeded = write_ray_flags_contains_flags(
1824                self,
1825                &mut block,
1826                initialized_tracker_id,
1827                RayQueryPoint::PROCEED.bits(),
1828            );
1829
1830            let finished_proceed_id = write_ray_flags_contains_flags(
1831                self,
1832                &mut block,
1833                initialized_tracker_id,
1834                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1835            );
1836
1837            let not_finished_id = self.id_gen.next();
1838            block.body.push(Instruction::unary(
1839                spirv::Op::LogicalNot,
1840                bool_type_id,
1841                not_finished_id,
1842                finished_proceed_id,
1843            ));
1844
1845            let valid_call = self.write_logical_and(&mut block, not_finished_id, has_proceeded);
1846
1847            block.body.push(Instruction::selection_merge(
1848                merge_id,
1849                spirv::SelectionControl::NONE,
1850            ));
1851
1852            Instruction::branch_conditional(valid_call, valid_block_id, merge_id)
1853        } else {
1854            Instruction::branch(valid_block_id)
1855        };
1856
1857        function.consume(block, instruction);
1858
1859        valid_block
1860            .body
1861            .push(Instruction::ray_query_terminate(query_id));
1862
1863        function.consume(valid_block, Instruction::branch(merge_id));
1864
1865        function.consume(merge_block, Instruction::return_void());
1866
1867        function.to_words(&mut self.logical_layout.function_definitions);
1868
1869        self.ray_query_functions
1870            .insert(LookupRayQueryFunction::Proceed, func_id);
1871        func_id
1872    }
1873}
1874
1875impl BlockContext<'_> {
1876    pub(super) fn write_ray_query_function(
1877        &mut self,
1878        query: Handle<crate::Expression>,
1879        function: &crate::RayQueryFunction,
1880        block: &mut Block,
1881    ) {
1882        let query_id = self.cached[query];
1883        let tracker_ids = *self
1884            .ray_query_tracker_expr
1885            .get(&query)
1886            .expect("not a cached ray query");
1887
1888        match *function {
1889            crate::RayQueryFunction::Initialize {
1890                acceleration_structure,
1891                descriptor,
1892            } => {
1893                let desc_id = self.cached[descriptor];
1894                let acc_struct_id = self.get_handle_id(acceleration_structure);
1895
1896                let func = self.writer.write_ray_query_initialize(self.ir_module);
1897
1898                let func_id = self.gen_id();
1899                block.body.push(Instruction::function_call(
1900                    self.writer.void_type,
1901                    func_id,
1902                    func,
1903                    &[
1904                        query_id,
1905                        acc_struct_id,
1906                        desc_id,
1907                        tracker_ids.initialized_tracker,
1908                        tracker_ids.t_max_tracker,
1909                    ],
1910                ));
1911            }
1912            crate::RayQueryFunction::Proceed { result } => {
1913                let id = self.gen_id();
1914                self.cached[result] = id;
1915
1916                let bool_ty = self.writer.get_bool_type_id();
1917
1918                let func_id = self.writer.write_ray_query_proceed();
1919                block.body.push(Instruction::function_call(
1920                    bool_ty,
1921                    id,
1922                    func_id,
1923                    &[query_id, tracker_ids.initialized_tracker],
1924                ));
1925            }
1926            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1927                let hit_id = self.cached[hit_t];
1928
1929                let func_id = self.writer.write_ray_query_generate_intersection();
1930
1931                let func_call_id = self.gen_id();
1932                block.body.push(Instruction::function_call(
1933                    self.writer.void_type,
1934                    func_call_id,
1935                    func_id,
1936                    &[
1937                        query_id,
1938                        tracker_ids.initialized_tracker,
1939                        hit_id,
1940                        tracker_ids.t_max_tracker,
1941                    ],
1942                ));
1943            }
1944            crate::RayQueryFunction::ConfirmIntersection => {
1945                let func_id = self.writer.write_ray_query_confirm_intersection();
1946
1947                let func_call_id = self.gen_id();
1948                block.body.push(Instruction::function_call(
1949                    self.writer.void_type,
1950                    func_call_id,
1951                    func_id,
1952                    &[query_id, tracker_ids.initialized_tracker],
1953                ));
1954            }
1955            crate::RayQueryFunction::Terminate => {
1956                let id = self.gen_id();
1957
1958                let func_id = self.writer.write_ray_query_terminate();
1959                block.body.push(Instruction::function_call(
1960                    self.writer.void_type,
1961                    id,
1962                    func_id,
1963                    &[query_id, tracker_ids.initialized_tracker],
1964                ));
1965            }
1966        }
1967    }
1968
1969    pub(super) fn write_ray_query_return_vertex_position(
1970        &mut self,
1971        query: Handle<crate::Expression>,
1972        block: &mut Block,
1973        is_committed: bool,
1974    ) -> spirv::Word {
1975        let fn_id = self
1976            .writer
1977            .write_ray_query_get_vertex_positions(is_committed, self.ir_module);
1978
1979        let query_id = self.cached[query];
1980        let tracker_id = *self
1981            .ray_query_tracker_expr
1982            .get(&query)
1983            .expect("not a cached ray query");
1984
1985        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1986            *self
1987                .ir_module
1988                .special_types
1989                .ray_vertex_return
1990                .as_ref()
1991                .expect("must be generated when reading in get vertex position"),
1992        );
1993
1994        let func_call_id = self.gen_id();
1995        block.body.push(Instruction::function_call(
1996            rq_get_vertex_positions_ty_id,
1997            func_call_id,
1998            fn_id,
1999            &[query_id, tracker_id.initialized_tracker],
2000        ));
2001        func_call_id
2002    }
2003}