naga/back/spv/ray/query.rs

1/*!
2Generating SPIR-V for ray query operations.
3*/
4
5use super::{
6    super::{
7        Block, BlockContext, Instruction, LocalType, LookupRayQueryFunction, NumericType, Writer,
8        WriterFlags,
9    },
10    write_ray_flags_contains_flags,
11};
12use crate::{arena::Handle, back::RayQueryPoint};
13
14impl Writer {
15    pub(in super::super) fn write_ray_query_get_intersection_function(
16        &mut self,
17        is_committed: bool,
18        ir_module: &crate::Module,
19    ) -> spirv::Word {
20        if let Some(&word) =
21            self.ray_query_functions
22                .get(&LookupRayQueryFunction::GetIntersection {
23                    committed: is_committed,
24                })
25        {
26            return word;
27        }
28        let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
29        let intersection_type_id = self.get_handle_type_id(ray_intersection);
30        let intersection_pointer_type_id =
31            self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
32
33        let flag_type_id = self.get_u32_type_id();
34        let flag_pointer_type_id =
35            self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
36
37        let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
38            columns: crate::VectorSize::Quad,
39            rows: crate::VectorSize::Tri,
40            scalar: crate::Scalar::F32,
41        });
42        let transform_pointer_type_id =
43            self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
44
45        let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
46            size: crate::VectorSize::Bi,
47            scalar: crate::Scalar::F32,
48        });
49        let barycentrics_pointer_type_id =
50            self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
51
52        let bool_type_id = self.get_bool_type_id();
53        let bool_pointer_type_id =
54            self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
55
56        let scalar_type_id = self.get_f32_type_id();
57        let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
58
59        let argument_type_id = self.get_ray_query_pointer_id();
60
61        let (func_id, mut function, arg_ids) = self.write_function_signature(
62            &[argument_type_id, flag_pointer_type_id],
63            intersection_type_id,
64        );
65
66        let query_id = arg_ids[0];
67        let intersection_tracker_id = arg_ids[1];
68
69        let label_id = self.id_gen.next();
70        let mut block = Block::new(label_id);
71
72        let blank_intersection = self.get_constant_null(intersection_type_id);
73        let blank_intersection_id = self.id_gen.next();
74        // This must be before everything else in the function.
75        block.body.push(Instruction::variable(
76            intersection_pointer_type_id,
77            blank_intersection_id,
78            spirv::StorageClass::Function,
79            Some(blank_intersection),
80        ));
81
82        let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
83            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
84        } else {
85            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
86        } as _));
87
88        let loaded_ray_query_tracker_id = self.id_gen.next();
89        block.body.push(Instruction::load(
90            flag_type_id,
91            loaded_ray_query_tracker_id,
92            intersection_tracker_id,
93            None,
94        ));
95        let proceeded_id = write_ray_flags_contains_flags(
96            self,
97            &mut block,
98            loaded_ray_query_tracker_id,
99            RayQueryPoint::PROCEED.bits(),
100        );
101        let finished_proceed_id = write_ray_flags_contains_flags(
102            self,
103            &mut block,
104            loaded_ray_query_tracker_id,
105            RayQueryPoint::FINISHED_TRAVERSAL.bits(),
106        );
107        let proceed_finished_correct_id = if is_committed {
108            finished_proceed_id
109        } else {
110            let not_finished_id = self.id_gen.next();
111            block.body.push(Instruction::unary(
112                spirv::Op::LogicalNot,
113                bool_type_id,
114                not_finished_id,
115                finished_proceed_id,
116            ));
117            not_finished_id
118        };
119
120        let is_valid_id =
121            self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id);
122
123        let valid_id = self.id_gen.next();
124        let mut valid_block = Block::new(valid_id);
125
126        let final_label_id = self.id_gen.next();
127        let mut final_block = Block::new(final_label_id);
128
129        block.body.push(Instruction::selection_merge(
130            final_label_id,
131            spirv::SelectionControl::NONE,
132        ));
133        function.consume(
134            block,
135            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id),
136        );
137
138        let raw_kind_id = self.id_gen.next();
139        valid_block
140            .body
141            .push(Instruction::ray_query_get_intersection(
142                spirv::Op::RayQueryGetIntersectionTypeKHR,
143                flag_type_id,
144                raw_kind_id,
145                query_id,
146                intersection_id,
147            ));
148        let kind_id = if is_committed {
149            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
150            raw_kind_id
151        } else {
152            // Remap from the candidate kind to IR
153            let condition_id = self.id_gen.next();
154            let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
155                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
156                    as _,
157            ));
158            valid_block.body.push(Instruction::binary(
159                spirv::Op::IEqual,
160                self.get_bool_type_id(),
161                condition_id,
162                raw_kind_id,
163                committed_triangle_kind_id,
164            ));
165            let kind_id = self.id_gen.next();
166            valid_block.body.push(Instruction::select(
167                flag_type_id,
168                kind_id,
169                condition_id,
170                self.get_constant_scalar(crate::Literal::U32(
171                    crate::RayQueryIntersection::Triangle as _,
172                )),
173                self.get_constant_scalar(crate::Literal::U32(
174                    crate::RayQueryIntersection::Aabb as _,
175                )),
176            ));
177            kind_id
178        };
179        let idx_id = self.get_index_constant(0);
180        let access_idx = self.id_gen.next();
181        valid_block.body.push(Instruction::access_chain(
182            flag_pointer_type_id,
183            access_idx,
184            blank_intersection_id,
185            &[idx_id],
186        ));
187        valid_block
188            .body
189            .push(Instruction::store(access_idx, kind_id, None));
190
191        let not_none_comp_id = self.id_gen.next();
192        let none_id =
193            self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
194        valid_block.body.push(Instruction::binary(
195            spirv::Op::INotEqual,
196            self.get_bool_type_id(),
197            not_none_comp_id,
198            kind_id,
199            none_id,
200        ));
201
202        let not_none_label_id = self.id_gen.next();
203        let mut not_none_block = Block::new(not_none_label_id);
204
205        let outer_merge_label_id = self.id_gen.next();
206        let outer_merge_block = Block::new(outer_merge_label_id);
207
208        valid_block.body.push(Instruction::selection_merge(
209            outer_merge_label_id,
210            spirv::SelectionControl::NONE,
211        ));
212        function.consume(
213            valid_block,
214            Instruction::branch_conditional(
215                not_none_comp_id,
216                not_none_label_id,
217                outer_merge_label_id,
218            ),
219        );
220
221        let instance_custom_index_id = self.id_gen.next();
222        not_none_block
223            .body
224            .push(Instruction::ray_query_get_intersection(
225                spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
226                flag_type_id,
227                instance_custom_index_id,
228                query_id,
229                intersection_id,
230            ));
231        let instance_id = self.id_gen.next();
232        not_none_block
233            .body
234            .push(Instruction::ray_query_get_intersection(
235                spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
236                flag_type_id,
237                instance_id,
238                query_id,
239                intersection_id,
240            ));
241        let sbt_record_offset_id = self.id_gen.next();
242        not_none_block
243            .body
244            .push(Instruction::ray_query_get_intersection(
245                spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
246                flag_type_id,
247                sbt_record_offset_id,
248                query_id,
249                intersection_id,
250            ));
251        let geometry_index_id = self.id_gen.next();
252        not_none_block
253            .body
254            .push(Instruction::ray_query_get_intersection(
255                spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
256                flag_type_id,
257                geometry_index_id,
258                query_id,
259                intersection_id,
260            ));
261        let primitive_index_id = self.id_gen.next();
262        not_none_block
263            .body
264            .push(Instruction::ray_query_get_intersection(
265                spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
266                flag_type_id,
267                primitive_index_id,
268                query_id,
269                intersection_id,
270            ));
271
272        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
273        // but it's not a property of an intersection.
274
275        let object_to_world_id = self.id_gen.next();
276        not_none_block
277            .body
278            .push(Instruction::ray_query_get_intersection(
279                spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
280                transform_type_id,
281                object_to_world_id,
282                query_id,
283                intersection_id,
284            ));
285        let world_to_object_id = self.id_gen.next();
286        not_none_block
287            .body
288            .push(Instruction::ray_query_get_intersection(
289                spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
290                transform_type_id,
291                world_to_object_id,
292                query_id,
293                intersection_id,
294            ));
295
296        // instance custom index
297        let idx_id = self.get_index_constant(2);
298        let access_idx = self.id_gen.next();
299        not_none_block.body.push(Instruction::access_chain(
300            flag_pointer_type_id,
301            access_idx,
302            blank_intersection_id,
303            &[idx_id],
304        ));
305        not_none_block.body.push(Instruction::store(
306            access_idx,
307            instance_custom_index_id,
308            None,
309        ));
310
311        // instance
312        let idx_id = self.get_index_constant(3);
313        let access_idx = self.id_gen.next();
314        not_none_block.body.push(Instruction::access_chain(
315            flag_pointer_type_id,
316            access_idx,
317            blank_intersection_id,
318            &[idx_id],
319        ));
320        not_none_block
321            .body
322            .push(Instruction::store(access_idx, instance_id, None));
323
324        let idx_id = self.get_index_constant(4);
325        let access_idx = self.id_gen.next();
326        not_none_block.body.push(Instruction::access_chain(
327            flag_pointer_type_id,
328            access_idx,
329            blank_intersection_id,
330            &[idx_id],
331        ));
332        not_none_block
333            .body
334            .push(Instruction::store(access_idx, sbt_record_offset_id, None));
335
336        let idx_id = self.get_index_constant(5);
337        let access_idx = self.id_gen.next();
338        not_none_block.body.push(Instruction::access_chain(
339            flag_pointer_type_id,
340            access_idx,
341            blank_intersection_id,
342            &[idx_id],
343        ));
344        not_none_block
345            .body
346            .push(Instruction::store(access_idx, geometry_index_id, None));
347
348        let idx_id = self.get_index_constant(6);
349        let access_idx = self.id_gen.next();
350        not_none_block.body.push(Instruction::access_chain(
351            flag_pointer_type_id,
352            access_idx,
353            blank_intersection_id,
354            &[idx_id],
355        ));
356        not_none_block
357            .body
358            .push(Instruction::store(access_idx, primitive_index_id, None));
359
360        let idx_id = self.get_index_constant(9);
361        let access_idx = self.id_gen.next();
362        not_none_block.body.push(Instruction::access_chain(
363            transform_pointer_type_id,
364            access_idx,
365            blank_intersection_id,
366            &[idx_id],
367        ));
368        not_none_block
369            .body
370            .push(Instruction::store(access_idx, object_to_world_id, None));
371
372        let idx_id = self.get_index_constant(10);
373        let access_idx = self.id_gen.next();
374        not_none_block.body.push(Instruction::access_chain(
375            transform_pointer_type_id,
376            access_idx,
377            blank_intersection_id,
378            &[idx_id],
379        ));
380        not_none_block
381            .body
382            .push(Instruction::store(access_idx, world_to_object_id, None));
383
384        let tri_comp_id = self.id_gen.next();
385        let tri_id = self.get_constant_scalar(crate::Literal::U32(
386            crate::RayQueryIntersection::Triangle as _,
387        ));
388        not_none_block.body.push(Instruction::binary(
389            spirv::Op::IEqual,
390            self.get_bool_type_id(),
391            tri_comp_id,
392            kind_id,
393            tri_id,
394        ));
395
396        let tri_label_id = self.id_gen.next();
397        let mut tri_block = Block::new(tri_label_id);
398
399        let merge_label_id = self.id_gen.next();
400        let merge_block = Block::new(merge_label_id);
401        // t
402        {
403            let block = if is_committed {
404                &mut not_none_block
405            } else {
406                &mut tri_block
407            };
408            let t_id = self.id_gen.next();
409            block.body.push(Instruction::ray_query_get_intersection(
410                spirv::Op::RayQueryGetIntersectionTKHR,
411                scalar_type_id,
412                t_id,
413                query_id,
414                intersection_id,
415            ));
416            let idx_id = self.get_index_constant(1);
417            let access_idx = self.id_gen.next();
418            block.body.push(Instruction::access_chain(
419                float_pointer_type_id,
420                access_idx,
421                blank_intersection_id,
422                &[idx_id],
423            ));
424            block.body.push(Instruction::store(access_idx, t_id, None));
425        }
426        not_none_block.body.push(Instruction::selection_merge(
427            merge_label_id,
428            spirv::SelectionControl::NONE,
429        ));
430        function.consume(
431            not_none_block,
432            Instruction::branch_conditional(tri_comp_id, tri_label_id, merge_label_id),
433        );
434
435        let barycentrics_id = self.id_gen.next();
436        tri_block.body.push(Instruction::ray_query_get_intersection(
437            spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
438            barycentrics_type_id,
439            barycentrics_id,
440            query_id,
441            intersection_id,
442        ));
443
444        let front_face_id = self.id_gen.next();
445        tri_block.body.push(Instruction::ray_query_get_intersection(
446            spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
447            bool_type_id,
448            front_face_id,
449            query_id,
450            intersection_id,
451        ));
452
453        let idx_id = self.get_index_constant(7);
454        let access_idx = self.id_gen.next();
455        tri_block.body.push(Instruction::access_chain(
456            barycentrics_pointer_type_id,
457            access_idx,
458            blank_intersection_id,
459            &[idx_id],
460        ));
461        tri_block
462            .body
463            .push(Instruction::store(access_idx, barycentrics_id, None));
464
465        let idx_id = self.get_index_constant(8);
466        let access_idx = self.id_gen.next();
467        tri_block.body.push(Instruction::access_chain(
468            bool_pointer_type_id,
469            access_idx,
470            blank_intersection_id,
471            &[idx_id],
472        ));
473        tri_block
474            .body
475            .push(Instruction::store(access_idx, front_face_id, None));
476        function.consume(tri_block, Instruction::branch(merge_label_id));
477        function.consume(merge_block, Instruction::branch(outer_merge_label_id));
478        function.consume(outer_merge_block, Instruction::branch(final_label_id));
479
480        let loaded_blank_intersection_id = self.id_gen.next();
481        final_block.body.push(Instruction::load(
482            intersection_type_id,
483            loaded_blank_intersection_id,
484            blank_intersection_id,
485            None,
486        ));
487        function.consume(
488            final_block,
489            Instruction::return_value(loaded_blank_intersection_id),
490        );
491
492        function.to_words(&mut self.logical_layout.function_definitions);
493        self.ray_query_functions.insert(
494            LookupRayQueryFunction::GetIntersection {
495                committed: is_committed,
496            },
497            func_id,
498        );
499        func_id
500    }
501
502    fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word {
503        if let Some(&word) = self
504            .ray_query_functions
505            .get(&LookupRayQueryFunction::Initialize)
506        {
507            return word;
508        }
509
510        let ray_query_type_id = self.get_ray_query_pointer_id();
511        let acceleration_structure_type_id =
512            self.get_localtype_id(LocalType::AccelerationStructure);
513        let ray_desc_type_id = self.get_handle_type_id(
514            ir_module
515                .special_types
516                .ray_desc
517                .expect("ray desc should be set if ray queries are being initialized"),
518        );
519
520        let u32_ty = self.get_u32_type_id();
521        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
522
523        let f32_type_id = self.get_f32_type_id();
524        let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
525
526        let (func_id, mut function, arg_ids) = self.write_function_signature(
527            &[
528                ray_query_type_id,
529                acceleration_structure_type_id,
530                ray_desc_type_id,
531                u32_ptr_ty,
532                f32_ptr_ty,
533            ],
534            self.void_type,
535        );
536
537        let query_id = arg_ids[0];
538        let acceleration_structure_id = arg_ids[1];
539        let desc_id = arg_ids[2];
540        let init_tracker_id = arg_ids[3];
541        let t_max_tracker_id = arg_ids[4];
542
543        let label_id = self.id_gen.next();
544        let mut block = Block::new(label_id);
545
546        let super::ExtractedRayDesc {
547            ray_flags_id,
548            cull_mask_id,
549            tmin_id,
550            tmax_id,
551            ray_origin_id,
552            ray_dir_id,
553            valid_id,
554        } = self.write_extract_ray_desc(
555            &mut block,
556            desc_id,
557            self.ray_query_initialization_tracking,
558        );
559
560        block
561            .body
562            .push(Instruction::store(t_max_tracker_id, tmax_id, None));
563
564        let merge_label_id = self.id_gen.next();
565        let merge_block = Block::new(merge_label_id);
566
567        // NOTE: this block will be unreachable if initialization tracking is disabled.
568        let invalid_label_id = self.id_gen.next();
569        let mut invalid_block = Block::new(invalid_label_id);
570
571        let valid_label_id = self.id_gen.next();
572        let mut valid_block = Block::new(valid_label_id);
573
574        match valid_id {
575            Some(all_valid_id) => {
576                block.body.push(Instruction::selection_merge(
577                    merge_label_id,
578                    spirv::SelectionControl::NONE,
579                ));
580                function.consume(
581                    block,
582                    Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
583                );
584            }
585            None => {
586                function.consume(block, Instruction::branch(valid_label_id));
587            }
588        }
589
590        valid_block.body.push(Instruction::ray_query_initialize(
591            query_id,
592            acceleration_structure_id,
593            ray_flags_id,
594            cull_mask_id,
595            ray_origin_id,
596            tmin_id,
597            ray_dir_id,
598            tmax_id,
599        ));
600
601        let const_initialized =
602            self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::INITIALIZED.bits()));
603        valid_block
604            .body
605            .push(Instruction::store(init_tracker_id, const_initialized, None));
606
607        function.consume(valid_block, Instruction::branch(merge_label_id));
608
609        if self
610            .flags
611            .contains(WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
612        {
613            self.write_debug_printf(
614                &mut invalid_block,
615                "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
616                &[
617                    ray_flags_id,
618                    tmin_id,
619                    tmax_id,
620                    ray_origin_id,
621                    ray_dir_id,
622                ],
623            );
624        }
625
626        function.consume(invalid_block, Instruction::branch(merge_label_id));
627
628        function.consume(merge_block, Instruction::return_void());
629
630        function.to_words(&mut self.logical_layout.function_definitions);
631
632        self.ray_query_functions
633            .insert(LookupRayQueryFunction::Initialize, func_id);
634        func_id
635    }
636
637    fn write_ray_query_proceed(&mut self) -> spirv::Word {
638        if let Some(&word) = self
639            .ray_query_functions
640            .get(&LookupRayQueryFunction::Proceed)
641        {
642            return word;
643        }
644
645        let ray_query_type_id = self.get_ray_query_pointer_id();
646
647        let u32_ty = self.get_u32_type_id();
648        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
649
650        let bool_type_id = self.get_bool_type_id();
651        let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
652
653        let (func_id, mut function, arg_ids) =
654            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id);
655
656        let query_id = arg_ids[0];
657        let init_tracker_id = arg_ids[1];
658
659        let block_id = self.id_gen.next();
660        let mut block = Block::new(block_id);
661
662        // TODO: perhaps this could be replaced with an OpPhi?
663        let proceeded_id = self.id_gen.next();
664        let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
665        block.body.push(Instruction::variable(
666            bool_ptr_ty,
667            proceeded_id,
668            spirv::StorageClass::Function,
669            Some(const_false),
670        ));
671
672        let initialized_tracker_id = self.id_gen.next();
673        block.body.push(Instruction::load(
674            u32_ty,
675            initialized_tracker_id,
676            init_tracker_id,
677            None,
678        ));
679
680        let merge_id = self.id_gen.next();
681        let mut merge_block = Block::new(merge_id);
682
683        let valid_block_id = self.id_gen.next();
684        let mut valid_block = Block::new(valid_block_id);
685
686        let instruction = if self.ray_query_initialization_tracking {
687            let is_initialized = write_ray_flags_contains_flags(
688                self,
689                &mut block,
690                initialized_tracker_id,
691                RayQueryPoint::INITIALIZED.bits(),
692            );
693
694            // Unlike in HLSL, in SPIR-V proceed is only guaranteed to return false once,
695            // after that it is UB to call. Therefore, don't call proceed if we have
696            // already finished as this will appear to have the same behaviour (after the
697            // first call that returns false, all subsequent ones will also return false)
698            let finished_proceed_id = write_ray_flags_contains_flags(
699                self,
700                &mut block,
701                initialized_tracker_id,
702                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
703            );
704
705            let not_finished_id = self.id_gen.next();
706            block.body.push(Instruction::unary(
707                spirv::Op::LogicalNot,
708                bool_type_id,
709                not_finished_id,
710                finished_proceed_id,
711            ));
712
713            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, is_initialized);
714
715            block.body.push(Instruction::selection_merge(
716                merge_id,
717                spirv::SelectionControl::NONE,
718            ));
719
720            Instruction::branch_conditional(is_valid_id, valid_block_id, merge_id)
721        } else {
722            Instruction::branch(valid_block_id)
723        };
724
725        function.consume(block, instruction);
726
727        let has_proceeded = self.id_gen.next();
728        valid_block.body.push(Instruction::ray_query_proceed(
729            bool_type_id,
730            has_proceeded,
731            query_id,
732        ));
733
734        valid_block
735            .body
736            .push(Instruction::store(proceeded_id, has_proceeded, None));
737
738        let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
739            (RayQueryPoint::PROCEED | RayQueryPoint::FINISHED_TRAVERSAL).bits(),
740        ));
741        let add_flag_continuing =
742            self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::PROCEED.bits()));
743
744        let add_flags_id = self.id_gen.next();
745        valid_block.body.push(Instruction::select(
746            u32_ty,
747            add_flags_id,
748            has_proceeded,
749            add_flag_continuing,
750            add_flag_finished,
751        ));
752        let final_flags = self.id_gen.next();
753        valid_block.body.push(Instruction::binary(
754            spirv::Op::BitwiseOr,
755            u32_ty,
756            final_flags,
757            initialized_tracker_id,
758            add_flags_id,
759        ));
760        valid_block
761            .body
762            .push(Instruction::store(init_tracker_id, final_flags, None));
763
764        function.consume(valid_block, Instruction::branch(merge_id));
765
766        let loaded_proceeded_id = self.id_gen.next();
767        merge_block.body.push(Instruction::load(
768            bool_type_id,
769            loaded_proceeded_id,
770            proceeded_id,
771            None,
772        ));
773
774        function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));
775
776        function.to_words(&mut self.logical_layout.function_definitions);
777
778        self.ray_query_functions
779            .insert(LookupRayQueryFunction::Proceed, func_id);
780        func_id
781    }
782
783    fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
784        if let Some(&word) = self
785            .ray_query_functions
786            .get(&LookupRayQueryFunction::GenerateIntersection)
787        {
788            return word;
789        }
790
791        let ray_query_type_id = self.get_ray_query_pointer_id();
792
793        let u32_ty = self.get_u32_type_id();
794        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
795
796        let f32_type_id = self.get_f32_type_id();
797        let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
798
799        let bool_type_id = self.get_bool_type_id();
800
801        let (func_id, mut function, arg_ids) = self.write_function_signature(
802            &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id],
803            self.void_type,
804        );
805
806        let query_id = arg_ids[0];
807        let init_tracker_id = arg_ids[1];
808        let depth_id = arg_ids[2];
809        let t_max_tracker_id = arg_ids[3];
810
811        let block_id = self.id_gen.next();
812        let mut block = Block::new(block_id);
813
814        let current_t = self.id_gen.next();
815        block.body.push(Instruction::variable(
816            f32_ptr_type_id,
817            current_t,
818            spirv::StorageClass::Function,
819            None,
820        ));
821
822        let current_t = self.id_gen.next();
823        block.body.push(Instruction::variable(
824            f32_ptr_type_id,
825            current_t,
826            spirv::StorageClass::Function,
827            None,
828        ));
829
830        let valid_id = self.id_gen.next();
831        let mut valid_block = Block::new(valid_id);
832
833        let final_label_id = self.id_gen.next();
834        let final_block = Block::new(final_label_id);
835
836        let instruction = if self.ray_query_initialization_tracking {
837            let initialized_tracker_id = self.id_gen.next();
838            block.body.push(Instruction::load(
839                u32_ty,
840                initialized_tracker_id,
841                init_tracker_id,
842                None,
843            ));
844
845            let proceeded_id = write_ray_flags_contains_flags(
846                self,
847                &mut block,
848                initialized_tracker_id,
849                RayQueryPoint::PROCEED.bits(),
850            );
851            let finished_proceed_id = write_ray_flags_contains_flags(
852                self,
853                &mut block,
854                initialized_tracker_id,
855                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
856            );
857
858            // Can't find anything to suggest double calling this function is invalid.
859
860            let not_finished_id = self.id_gen.next();
861            block.body.push(Instruction::unary(
862                spirv::Op::LogicalNot,
863                bool_type_id,
864                not_finished_id,
865                finished_proceed_id,
866            ));
867
868            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
869
870            block.body.push(Instruction::selection_merge(
871                final_label_id,
872                spirv::SelectionControl::NONE,
873            ));
874
875            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
876        } else {
877            Instruction::branch(valid_id)
878        };
879
880        function.consume(block, instruction);
881
882        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
883            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
884        ));
885        let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32(
886            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
887        ));
888        let raw_kind_id = self.id_gen.next();
889        valid_block
890            .body
891            .push(Instruction::ray_query_get_intersection(
892                spirv::Op::RayQueryGetIntersectionTypeKHR,
893                u32_ty,
894                raw_kind_id,
895                query_id,
896                intersection_id,
897            ));
898
899        let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
900            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
901        ));
902        let intersection_aabb_id = self.id_gen.next();
903        valid_block.body.push(Instruction::binary(
904            spirv::Op::IEqual,
905            bool_type_id,
906            intersection_aabb_id,
907            raw_kind_id,
908            candidate_aabb_id,
909        ));
910
911        // Check that the provided t value is between t min and the current committed
912        // t value, (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryGenerateIntersectionKHR-06353)
913
914        // Get the tmin
915        let t_min_id = self.id_gen.next();
916        valid_block.body.push(Instruction::ray_query_get_t_min(
917            f32_type_id,
918            t_min_id,
919            query_id,
920        ));
921
922        // Get the current committed t, or tmax if no hit.
923        // Basically emulate HLSL's (easier) version
924        // Pseudo-code:
925        // ````wgsl
926        // // start of function
927        // var current_t:f32;
928        // ...
929        // let committed_type_id = RayQueryGetIntersectionTypeKHR<Committed>(query_id);
930        // if committed_type_id == Committed_None {
931        //     current_t = load(t_max_tracker);
932        // } else {
933        //     current_t = RayQueryGetIntersectionTKHR<Committed>(query_id);
934        // }
935        // ...
936        // ````
937
938        let committed_type_id = self.id_gen.next();
939        valid_block
940            .body
941            .push(Instruction::ray_query_get_intersection(
942                spirv::Op::RayQueryGetIntersectionTypeKHR,
943                u32_ty,
944                committed_type_id,
945                query_id,
946                committed_intersection_id,
947            ));
948
949        let no_committed = self.id_gen.next();
950        valid_block.body.push(Instruction::binary(
951            spirv::Op::IEqual,
952            bool_type_id,
953            no_committed,
954            committed_type_id,
955            self.get_constant_scalar(crate::Literal::U32(
956                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _,
957            )),
958        ));
959
960        let next_valid_block_id = self.id_gen.next();
961        let no_committed_block_id = self.id_gen.next();
962        let mut no_committed_block = Block::new(no_committed_block_id);
963        let committed_block_id = self.id_gen.next();
964        let mut committed_block = Block::new(committed_block_id);
965        valid_block.body.push(Instruction::selection_merge(
966            next_valid_block_id,
967            spirv::SelectionControl::NONE,
968        ));
969        function.consume(
970            valid_block,
971            Instruction::branch_conditional(
972                no_committed,
973                no_committed_block_id,
974                committed_block_id,
975            ),
976        );
977
978        // Assign t_max to current_t
979        let t_max_id = self.id_gen.next();
980        no_committed_block.body.push(Instruction::load(
981            f32_type_id,
982            t_max_id,
983            t_max_tracker_id,
984            None,
985        ));
986        no_committed_block
987            .body
988            .push(Instruction::store(current_t, t_max_id, None));
989        function.consume(no_committed_block, Instruction::branch(next_valid_block_id));
990
991        // Assign t_current to current_t
992        let latest_t_id = self.id_gen.next();
993        committed_block
994            .body
995            .push(Instruction::ray_query_get_intersection(
996                spirv::Op::RayQueryGetIntersectionTKHR,
997                f32_type_id,
998                latest_t_id,
999                query_id,
1000                intersection_id,
1001            ));
1002        committed_block
1003            .body
1004            .push(Instruction::store(current_t, latest_t_id, None));
1005        function.consume(committed_block, Instruction::branch(next_valid_block_id));
1006
1007        let mut valid_block = Block::new(next_valid_block_id);
1008
1009        let t_ge_t_min = self.id_gen.next();
1010        valid_block.body.push(Instruction::binary(
1011            spirv::Op::FOrdGreaterThanEqual,
1012            bool_type_id,
1013            t_ge_t_min,
1014            depth_id,
1015            t_min_id,
1016        ));
1017        let t_current = self.id_gen.next();
1018        valid_block
1019            .body
1020            .push(Instruction::load(f32_type_id, t_current, current_t, None));
1021        let t_le_t_current = self.id_gen.next();
1022        valid_block.body.push(Instruction::binary(
1023            spirv::Op::FOrdLessThanEqual,
1024            bool_type_id,
1025            t_le_t_current,
1026            depth_id,
1027            t_current,
1028        ));
1029
1030        let t_in_range = self.id_gen.next();
1031        valid_block.body.push(Instruction::binary(
1032            spirv::Op::LogicalAnd,
1033            bool_type_id,
1034            t_in_range,
1035            t_ge_t_min,
1036            t_le_t_current,
1037        ));
1038
1039        let call_valid_id = self.id_gen.next();
1040        valid_block.body.push(Instruction::binary(
1041            spirv::Op::LogicalAnd,
1042            bool_type_id,
1043            call_valid_id,
1044            t_in_range,
1045            intersection_aabb_id,
1046        ));
1047
1048        let generate_label_id = self.id_gen.next();
1049        let mut generate_block = Block::new(generate_label_id);
1050
1051        let merge_label_id = self.id_gen.next();
1052        let merge_block = Block::new(merge_label_id);
1053
1054        valid_block.body.push(Instruction::selection_merge(
1055            merge_label_id,
1056            spirv::SelectionControl::NONE,
1057        ));
1058        function.consume(
1059            valid_block,
1060            Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id),
1061        );
1062
1063        generate_block
1064            .body
1065            .push(Instruction::ray_query_generate_intersection(
1066                query_id, depth_id,
1067            ));
1068
1069        function.consume(generate_block, Instruction::branch(merge_label_id));
1070        function.consume(merge_block, Instruction::branch(final_label_id));
1071
1072        function.consume(final_block, Instruction::return_void());
1073
1074        function.to_words(&mut self.logical_layout.function_definitions);
1075
1076        self.ray_query_functions
1077            .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
1078        func_id
1079    }
1080
1081    fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
1082        if let Some(&word) = self
1083            .ray_query_functions
1084            .get(&LookupRayQueryFunction::ConfirmIntersection)
1085        {
1086            return word;
1087        }
1088
1089        let ray_query_type_id = self.get_ray_query_pointer_id();
1090
1091        let u32_ty = self.get_u32_type_id();
1092        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1093
1094        let bool_type_id = self.get_bool_type_id();
1095
1096        let (func_id, mut function, arg_ids) =
1097            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1098
1099        let query_id = arg_ids[0];
1100        let init_tracker_id = arg_ids[1];
1101
1102        let block_id = self.id_gen.next();
1103        let mut block = Block::new(block_id);
1104
1105        let valid_id = self.id_gen.next();
1106        let mut valid_block = Block::new(valid_id);
1107
1108        let final_label_id = self.id_gen.next();
1109        let final_block = Block::new(final_label_id);
1110
1111        let instruction = if self.ray_query_initialization_tracking {
1112            let initialized_tracker_id = self.id_gen.next();
1113            block.body.push(Instruction::load(
1114                u32_ty,
1115                initialized_tracker_id,
1116                init_tracker_id,
1117                None,
1118            ));
1119
1120            let proceeded_id = write_ray_flags_contains_flags(
1121                self,
1122                &mut block,
1123                initialized_tracker_id,
1124                RayQueryPoint::PROCEED.bits(),
1125            );
1126            let finished_proceed_id = write_ray_flags_contains_flags(
1127                self,
1128                &mut block,
1129                initialized_tracker_id,
1130                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1131            );
1132            // Although it seems strange to call this twice, I (Vecvec) can't find anything to suggest double calling this function is invalid.
1133            let not_finished_id = self.id_gen.next();
1134            block.body.push(Instruction::unary(
1135                spirv::Op::LogicalNot,
1136                bool_type_id,
1137                not_finished_id,
1138                finished_proceed_id,
1139            ));
1140
1141            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1142
1143            block.body.push(Instruction::selection_merge(
1144                final_label_id,
1145                spirv::SelectionControl::NONE,
1146            ));
1147
1148            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1149        } else {
1150            Instruction::branch(valid_id)
1151        };
1152
1153        function.consume(block, instruction);
1154
1155        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1156            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1157        ));
1158        let raw_kind_id = self.id_gen.next();
1159        valid_block
1160            .body
1161            .push(Instruction::ray_query_get_intersection(
1162                spirv::Op::RayQueryGetIntersectionTypeKHR,
1163                u32_ty,
1164                raw_kind_id,
1165                query_id,
1166                intersection_id,
1167            ));
1168
1169        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
1170            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
1171        ));
1172        let intersection_tri_id = self.id_gen.next();
1173        valid_block.body.push(Instruction::binary(
1174            spirv::Op::IEqual,
1175            bool_type_id,
1176            intersection_tri_id,
1177            raw_kind_id,
1178            candidate_tri_id,
1179        ));
1180
1181        let generate_label_id = self.id_gen.next();
1182        let mut generate_block = Block::new(generate_label_id);
1183
1184        let merge_label_id = self.id_gen.next();
1185        let merge_block = Block::new(merge_label_id);
1186
1187        valid_block.body.push(Instruction::selection_merge(
1188            merge_label_id,
1189            spirv::SelectionControl::NONE,
1190        ));
1191        function.consume(
1192            valid_block,
1193            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1194        );
1195
1196        generate_block
1197            .body
1198            .push(Instruction::ray_query_confirm_intersection(query_id));
1199
1200        function.consume(generate_block, Instruction::branch(merge_label_id));
1201        function.consume(merge_block, Instruction::branch(final_label_id));
1202
1203        function.consume(final_block, Instruction::return_void());
1204
1205        self.ray_query_functions
1206            .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);
1207
1208        function.to_words(&mut self.logical_layout.function_definitions);
1209
1210        func_id
1211    }
1212
1213    fn write_ray_query_get_vertex_positions(
1214        &mut self,
1215        is_committed: bool,
1216        ir_module: &crate::Module,
1217    ) -> spirv::Word {
1218        if let Some(&word) =
1219            self.ray_query_functions
1220                .get(&LookupRayQueryFunction::GetVertexPositions {
1221                    committed: is_committed,
1222                })
1223        {
1224            return word;
1225        }
1226
1227        let (committed_ty, committed_tri_ty) = if is_committed {
1228            (
1229                spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
1230                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
1231                    as u32,
1232            )
1233        } else {
1234            (
1235                spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
1236                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
1237                    as u32,
1238            )
1239        };
1240
1241        let ray_query_type_id = self.get_ray_query_pointer_id();
1242
1243        let u32_ty = self.get_u32_type_id();
1244        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1245
1246        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1247            *ir_module
1248                .special_types
1249                .ray_vertex_return
1250                .as_ref()
1251                .expect("must be generated when reading in get vertex position"),
1252        );
1253        let ptr_return_ty =
1254            self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);
1255
1256        let bool_type_id = self.get_bool_type_id();
1257
1258        let (func_id, mut function, arg_ids) = self.write_function_signature(
1259            &[ray_query_type_id, u32_ptr_ty],
1260            rq_get_vertex_positions_ty_id,
1261        );
1262
1263        let query_id = arg_ids[0];
1264        let init_tracker_id = arg_ids[1];
1265
1266        let block_id = self.id_gen.next();
1267        let mut block = Block::new(block_id);
1268
1269        let return_id = self.id_gen.next();
1270        block.body.push(Instruction::variable(
1271            ptr_return_ty,
1272            return_id,
1273            spirv::StorageClass::Function,
1274            Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
1275        ));
1276
1277        let valid_id = self.id_gen.next();
1278        let mut valid_block = Block::new(valid_id);
1279
1280        let final_label_id = self.id_gen.next();
1281        let mut final_block = Block::new(final_label_id);
1282
1283        let instruction = if self.ray_query_initialization_tracking {
1284            let initialized_tracker_id = self.id_gen.next();
1285            block.body.push(Instruction::load(
1286                u32_ty,
1287                initialized_tracker_id,
1288                init_tracker_id,
1289                None,
1290            ));
1291
1292            let proceeded_id = write_ray_flags_contains_flags(
1293                self,
1294                &mut block,
1295                initialized_tracker_id,
1296                RayQueryPoint::PROCEED.bits(),
1297            );
1298            let finished_proceed_id = write_ray_flags_contains_flags(
1299                self,
1300                &mut block,
1301                initialized_tracker_id,
1302                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1303            );
1304
1305            let correct_finish_id = if is_committed {
1306                finished_proceed_id
1307            } else {
1308                let not_finished_id = self.id_gen.next();
1309                block.body.push(Instruction::unary(
1310                    spirv::Op::LogicalNot,
1311                    bool_type_id,
1312                    not_finished_id,
1313                    finished_proceed_id,
1314                ));
1315                not_finished_id
1316            };
1317
1318            let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id);
1319            block.body.push(Instruction::selection_merge(
1320                final_label_id,
1321                spirv::SelectionControl::NONE,
1322            ));
1323            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1324        } else {
1325            Instruction::branch(valid_id)
1326        };
1327
1328        function.consume(block, instruction);
1329
1330        let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
1331        let raw_kind_id = self.id_gen.next();
1332        valid_block
1333            .body
1334            .push(Instruction::ray_query_get_intersection(
1335                spirv::Op::RayQueryGetIntersectionTypeKHR,
1336                u32_ty,
1337                raw_kind_id,
1338                query_id,
1339                intersection_id,
1340            ));
1341
1342        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
1343        let intersection_tri_id = self.id_gen.next();
1344        valid_block.body.push(Instruction::binary(
1345            spirv::Op::IEqual,
1346            bool_type_id,
1347            intersection_tri_id,
1348            raw_kind_id,
1349            candidate_tri_id,
1350        ));
1351
1352        let generate_label_id = self.id_gen.next();
1353        let mut vertex_return_block = Block::new(generate_label_id);
1354
1355        let merge_label_id = self.id_gen.next();
1356        let merge_block = Block::new(merge_label_id);
1357
1358        valid_block.body.push(Instruction::selection_merge(
1359            merge_label_id,
1360            spirv::SelectionControl::NONE,
1361        ));
1362        function.consume(
1363            valid_block,
1364            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1365        );
1366
1367        let vertices_id = self.id_gen.next();
1368        vertex_return_block
1369            .body
1370            .push(Instruction::ray_query_return_vertex_position(
1371                rq_get_vertex_positions_ty_id,
1372                vertices_id,
1373                query_id,
1374                intersection_id,
1375            ));
1376        vertex_return_block
1377            .body
1378            .push(Instruction::store(return_id, vertices_id, None));
1379
1380        function.consume(vertex_return_block, Instruction::branch(merge_label_id));
1381        function.consume(merge_block, Instruction::branch(final_label_id));
1382
1383        let loaded_pos_id = self.id_gen.next();
1384        final_block.body.push(Instruction::load(
1385            rq_get_vertex_positions_ty_id,
1386            loaded_pos_id,
1387            return_id,
1388            None,
1389        ));
1390
1391        function.consume(final_block, Instruction::return_value(loaded_pos_id));
1392
1393        self.ray_query_functions.insert(
1394            LookupRayQueryFunction::GetVertexPositions {
1395                committed: is_committed,
1396            },
1397            func_id,
1398        );
1399
1400        function.to_words(&mut self.logical_layout.function_definitions);
1401
1402        func_id
1403    }
1404
1405    fn write_ray_query_terminate(&mut self) -> spirv::Word {
1406        if let Some(&word) = self
1407            .ray_query_functions
1408            .get(&LookupRayQueryFunction::Terminate)
1409        {
1410            return word;
1411        }
1412
1413        let ray_query_type_id = self.get_ray_query_pointer_id();
1414
1415        let u32_ty = self.get_u32_type_id();
1416        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1417
1418        let bool_type_id = self.get_bool_type_id();
1419
1420        let (func_id, mut function, arg_ids) =
1421            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1422
1423        let query_id = arg_ids[0];
1424        let init_tracker_id = arg_ids[1];
1425
1426        let block_id = self.id_gen.next();
1427        let mut block = Block::new(block_id);
1428
1429        let initialized_tracker_id = self.id_gen.next();
1430        block.body.push(Instruction::load(
1431            u32_ty,
1432            initialized_tracker_id,
1433            init_tracker_id,
1434            None,
1435        ));
1436
1437        let merge_id = self.id_gen.next();
1438        let merge_block = Block::new(merge_id);
1439
1440        let valid_block_id = self.id_gen.next();
1441        let mut valid_block = Block::new(valid_block_id);
1442
1443        let instruction = if self.ray_query_initialization_tracking {
1444            let has_proceeded = write_ray_flags_contains_flags(
1445                self,
1446                &mut block,
1447                initialized_tracker_id,
1448                RayQueryPoint::PROCEED.bits(),
1449            );
1450
1451            let finished_proceed_id = write_ray_flags_contains_flags(
1452                self,
1453                &mut block,
1454                initialized_tracker_id,
1455                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1456            );
1457
1458            let not_finished_id = self.id_gen.next();
1459            block.body.push(Instruction::unary(
1460                spirv::Op::LogicalNot,
1461                bool_type_id,
1462                not_finished_id,
1463                finished_proceed_id,
1464            ));
1465
1466            let valid_call = self.write_logical_and(&mut block, not_finished_id, has_proceeded);
1467
1468            block.body.push(Instruction::selection_merge(
1469                merge_id,
1470                spirv::SelectionControl::NONE,
1471            ));
1472
1473            Instruction::branch_conditional(valid_call, valid_block_id, merge_id)
1474        } else {
1475            Instruction::branch(valid_block_id)
1476        };
1477
1478        function.consume(block, instruction);
1479
1480        valid_block
1481            .body
1482            .push(Instruction::ray_query_terminate(query_id));
1483
1484        function.consume(valid_block, Instruction::branch(merge_id));
1485
1486        function.consume(merge_block, Instruction::return_void());
1487
1488        function.to_words(&mut self.logical_layout.function_definitions);
1489
1490        self.ray_query_functions
1491            .insert(LookupRayQueryFunction::Terminate, func_id);
1492        func_id
1493    }
1494}
1495
1496impl BlockContext<'_> {
1497    pub(in super::super) fn write_ray_query_function(
1498        &mut self,
1499        query: Handle<crate::Expression>,
1500        function: &crate::RayQueryFunction,
1501        block: &mut Block,
1502    ) {
1503        let query_id = self.cached[query];
1504        let tracker_ids = *self
1505            .ray_query_tracker_expr
1506            .get(&query)
1507            .expect("not a cached ray query");
1508
1509        match *function {
1510            crate::RayQueryFunction::Initialize {
1511                acceleration_structure,
1512                descriptor,
1513            } => {
1514                let desc_id = self.cached[descriptor];
1515                let acc_struct_id = self.get_handle_id(acceleration_structure);
1516
1517                let func = self.writer.write_ray_query_initialize(self.ir_module);
1518
1519                let func_id = self.gen_id();
1520                block.body.push(Instruction::function_call(
1521                    self.writer.void_type,
1522                    func_id,
1523                    func,
1524                    &[
1525                        query_id,
1526                        acc_struct_id,
1527                        desc_id,
1528                        tracker_ids.initialized_tracker,
1529                        tracker_ids.t_max_tracker,
1530                    ],
1531                ));
1532            }
1533            crate::RayQueryFunction::Proceed { result } => {
1534                let id = self.gen_id();
1535                self.cached[result] = id;
1536
1537                let bool_ty = self.writer.get_bool_type_id();
1538
1539                let func_id = self.writer.write_ray_query_proceed();
1540                block.body.push(Instruction::function_call(
1541                    bool_ty,
1542                    id,
1543                    func_id,
1544                    &[query_id, tracker_ids.initialized_tracker],
1545                ));
1546            }
1547            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1548                let hit_id = self.cached[hit_t];
1549
1550                let func_id = self.writer.write_ray_query_generate_intersection();
1551
1552                let func_call_id = self.gen_id();
1553                block.body.push(Instruction::function_call(
1554                    self.writer.void_type,
1555                    func_call_id,
1556                    func_id,
1557                    &[
1558                        query_id,
1559                        tracker_ids.initialized_tracker,
1560                        hit_id,
1561                        tracker_ids.t_max_tracker,
1562                    ],
1563                ));
1564            }
1565            crate::RayQueryFunction::ConfirmIntersection => {
1566                let func_id = self.writer.write_ray_query_confirm_intersection();
1567
1568                let func_call_id = self.gen_id();
1569                block.body.push(Instruction::function_call(
1570                    self.writer.void_type,
1571                    func_call_id,
1572                    func_id,
1573                    &[query_id, tracker_ids.initialized_tracker],
1574                ));
1575            }
1576            crate::RayQueryFunction::Terminate => {
1577                let id = self.gen_id();
1578
1579                let func_id = self.writer.write_ray_query_terminate();
1580                block.body.push(Instruction::function_call(
1581                    self.writer.void_type,
1582                    id,
1583                    func_id,
1584                    &[query_id, tracker_ids.initialized_tracker],
1585                ));
1586            }
1587        }
1588    }
1589
1590    pub(in super::super) fn write_ray_query_return_vertex_position(
1591        &mut self,
1592        query: Handle<crate::Expression>,
1593        block: &mut Block,
1594        is_committed: bool,
1595    ) -> spirv::Word {
1596        let fn_id = self
1597            .writer
1598            .write_ray_query_get_vertex_positions(is_committed, self.ir_module);
1599
1600        let query_id = self.cached[query];
1601        let tracker_id = *self
1602            .ray_query_tracker_expr
1603            .get(&query)
1604            .expect("not a cached ray query");
1605
1606        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1607            *self
1608                .ir_module
1609                .special_types
1610                .ray_vertex_return
1611                .as_ref()
1612                .expect("must be generated when reading in get vertex position"),
1613        );
1614
1615        let func_call_id = self.gen_id();
1616        block.body.push(Instruction::function_call(
1617            rq_get_vertex_positions_ty_id,
1618            func_call_id,
1619            fn_id,
1620            &[query_id, tracker_id.initialized_tracker],
1621        ));
1622        func_call_id
1623    }
1624}