naga/back/spv/ray/
query.rs

1/*!
2Generating SPIR-V for ray query operations.
3*/
4
5use super::{
6    super::{
7        Block, BlockContext, Instruction, LocalType, LookupRayQueryFunction, NumericType, Writer,
8        WriterFlags,
9    },
10    write_ray_flags_contains_flags,
11};
12use crate::{arena::Handle, back::RayQueryPoint};
13
14impl Writer {
15    pub(in super::super) fn write_ray_query_get_intersection_function(
16        &mut self,
17        is_committed: bool,
18        ir_module: &crate::Module,
19    ) -> spirv::Word {
20        if let Some(&word) =
21            self.ray_query_functions
22                .get(&LookupRayQueryFunction::GetIntersection {
23                    committed: is_committed,
24                })
25        {
26            return word;
27        }
28        let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
29        let intersection_type_id = self.get_handle_type_id(ray_intersection);
30        let intersection_pointer_type_id =
31            self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
32
33        let flag_type_id = self.get_u32_type_id();
34        let flag_pointer_type_id =
35            self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
36
37        let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
38            columns: crate::VectorSize::Quad,
39            rows: crate::VectorSize::Tri,
40            scalar: crate::Scalar::F32,
41        });
42        let transform_pointer_type_id =
43            self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
44
45        let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
46            size: crate::VectorSize::Bi,
47            scalar: crate::Scalar::F32,
48        });
49        let barycentrics_pointer_type_id =
50            self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
51
52        let bool_type_id = self.get_bool_type_id();
53        let bool_pointer_type_id =
54            self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
55
56        let scalar_type_id = self.get_f32_type_id();
57        let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
58
59        let argument_type_id = self.get_ray_query_pointer_id();
60
61        let (func_id, mut function, arg_ids) = self.write_function_signature(
62            &[argument_type_id, flag_pointer_type_id],
63            intersection_type_id,
64        );
65
66        let query_id = arg_ids[0];
67        let intersection_tracker_id = arg_ids[1];
68
69        let label_id = self.id_gen.next();
70        let mut block = Block::new(label_id);
71
72        let blank_intersection = self.get_constant_null(intersection_type_id);
73        let blank_intersection_id = self.id_gen.next();
74        // This must be before everything else in the function.
75        block.body.push(Instruction::variable(
76            intersection_pointer_type_id,
77            blank_intersection_id,
78            spirv::StorageClass::Function,
79            Some(blank_intersection),
80        ));
81
82        let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
83            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
84        } else {
85            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
86        } as _));
87
88        let loaded_ray_query_tracker_id = self.id_gen.next();
89        block.body.push(Instruction::load(
90            flag_type_id,
91            loaded_ray_query_tracker_id,
92            intersection_tracker_id,
93            None,
94        ));
95        let proceeded_id = write_ray_flags_contains_flags(
96            self,
97            &mut block,
98            loaded_ray_query_tracker_id,
99            RayQueryPoint::PROCEED.bits(),
100        );
101        let finished_proceed_id = write_ray_flags_contains_flags(
102            self,
103            &mut block,
104            loaded_ray_query_tracker_id,
105            RayQueryPoint::FINISHED_TRAVERSAL.bits(),
106        );
107        let proceed_finished_correct_id = if is_committed {
108            finished_proceed_id
109        } else {
110            let not_finished_id = self.id_gen.next();
111            block.body.push(Instruction::unary(
112                spirv::Op::LogicalNot,
113                bool_type_id,
114                not_finished_id,
115                finished_proceed_id,
116            ));
117            not_finished_id
118        };
119
120        let is_valid_id =
121            self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id);
122
123        let valid_id = self.id_gen.next();
124        let mut valid_block = Block::new(valid_id);
125
126        let final_label_id = self.id_gen.next();
127        let mut final_block = Block::new(final_label_id);
128
129        block.body.push(Instruction::selection_merge(
130            final_label_id,
131            spirv::SelectionControl::NONE,
132        ));
133        function.consume(
134            block,
135            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id),
136        );
137
138        let raw_kind_id = self.id_gen.next();
139        valid_block
140            .body
141            .push(Instruction::ray_query_get_intersection(
142                spirv::Op::RayQueryGetIntersectionTypeKHR,
143                flag_type_id,
144                raw_kind_id,
145                query_id,
146                intersection_id,
147            ));
148        let kind_id = if is_committed {
149            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
150            raw_kind_id
151        } else {
152            // Remap from the candidate kind to IR
153            let condition_id = self.id_gen.next();
154            let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
155                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
156                    as _,
157            ));
158            valid_block.body.push(Instruction::binary(
159                spirv::Op::IEqual,
160                self.get_bool_type_id(),
161                condition_id,
162                raw_kind_id,
163                committed_triangle_kind_id,
164            ));
165            let kind_id = self.id_gen.next();
166            valid_block.body.push(Instruction::select(
167                flag_type_id,
168                kind_id,
169                condition_id,
170                self.get_constant_scalar(crate::Literal::U32(
171                    crate::RayQueryIntersection::Triangle as _,
172                )),
173                self.get_constant_scalar(crate::Literal::U32(
174                    crate::RayQueryIntersection::Aabb as _,
175                )),
176            ));
177            kind_id
178        };
179        let idx_id = self.get_index_constant(0);
180        let access_idx = self.id_gen.next();
181        valid_block.body.push(Instruction::access_chain(
182            flag_pointer_type_id,
183            access_idx,
184            blank_intersection_id,
185            &[idx_id],
186        ));
187        valid_block
188            .body
189            .push(Instruction::store(access_idx, kind_id, None));
190
191        let not_none_comp_id = self.id_gen.next();
192        let none_id =
193            self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
194        valid_block.body.push(Instruction::binary(
195            spirv::Op::INotEqual,
196            self.get_bool_type_id(),
197            not_none_comp_id,
198            kind_id,
199            none_id,
200        ));
201
202        let not_none_label_id = self.id_gen.next();
203        let mut not_none_block = Block::new(not_none_label_id);
204
205        let outer_merge_label_id = self.id_gen.next();
206        let outer_merge_block = Block::new(outer_merge_label_id);
207
208        valid_block.body.push(Instruction::selection_merge(
209            outer_merge_label_id,
210            spirv::SelectionControl::NONE,
211        ));
212        function.consume(
213            valid_block,
214            Instruction::branch_conditional(
215                not_none_comp_id,
216                not_none_label_id,
217                outer_merge_label_id,
218            ),
219        );
220
221        let instance_custom_index_id = self.id_gen.next();
222        not_none_block
223            .body
224            .push(Instruction::ray_query_get_intersection(
225                spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
226                flag_type_id,
227                instance_custom_index_id,
228                query_id,
229                intersection_id,
230            ));
231        let instance_id = self.id_gen.next();
232        not_none_block
233            .body
234            .push(Instruction::ray_query_get_intersection(
235                spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
236                flag_type_id,
237                instance_id,
238                query_id,
239                intersection_id,
240            ));
241        let sbt_record_offset_id = self.id_gen.next();
242        not_none_block
243            .body
244            .push(Instruction::ray_query_get_intersection(
245                spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
246                flag_type_id,
247                sbt_record_offset_id,
248                query_id,
249                intersection_id,
250            ));
251        let geometry_index_id = self.id_gen.next();
252        not_none_block
253            .body
254            .push(Instruction::ray_query_get_intersection(
255                spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
256                flag_type_id,
257                geometry_index_id,
258                query_id,
259                intersection_id,
260            ));
261        let primitive_index_id = self.id_gen.next();
262        not_none_block
263            .body
264            .push(Instruction::ray_query_get_intersection(
265                spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
266                flag_type_id,
267                primitive_index_id,
268                query_id,
269                intersection_id,
270            ));
271
272        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
273        // but it's not a property of an intersection.
274
275        let object_to_world_id = self.id_gen.next();
276        not_none_block
277            .body
278            .push(Instruction::ray_query_get_intersection(
279                spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
280                transform_type_id,
281                object_to_world_id,
282                query_id,
283                intersection_id,
284            ));
285        let world_to_object_id = self.id_gen.next();
286        not_none_block
287            .body
288            .push(Instruction::ray_query_get_intersection(
289                spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
290                transform_type_id,
291                world_to_object_id,
292                query_id,
293                intersection_id,
294            ));
295
296        // instance custom index
297        let idx_id = self.get_index_constant(2);
298        let access_idx = self.id_gen.next();
299        not_none_block.body.push(Instruction::access_chain(
300            flag_pointer_type_id,
301            access_idx,
302            blank_intersection_id,
303            &[idx_id],
304        ));
305        not_none_block.body.push(Instruction::store(
306            access_idx,
307            instance_custom_index_id,
308            None,
309        ));
310
311        // instance
312        let idx_id = self.get_index_constant(3);
313        let access_idx = self.id_gen.next();
314        not_none_block.body.push(Instruction::access_chain(
315            flag_pointer_type_id,
316            access_idx,
317            blank_intersection_id,
318            &[idx_id],
319        ));
320        not_none_block
321            .body
322            .push(Instruction::store(access_idx, instance_id, None));
323
324        let idx_id = self.get_index_constant(4);
325        let access_idx = self.id_gen.next();
326        not_none_block.body.push(Instruction::access_chain(
327            flag_pointer_type_id,
328            access_idx,
329            blank_intersection_id,
330            &[idx_id],
331        ));
332        not_none_block
333            .body
334            .push(Instruction::store(access_idx, sbt_record_offset_id, None));
335
336        let idx_id = self.get_index_constant(5);
337        let access_idx = self.id_gen.next();
338        not_none_block.body.push(Instruction::access_chain(
339            flag_pointer_type_id,
340            access_idx,
341            blank_intersection_id,
342            &[idx_id],
343        ));
344        not_none_block
345            .body
346            .push(Instruction::store(access_idx, geometry_index_id, None));
347
348        let idx_id = self.get_index_constant(6);
349        let access_idx = self.id_gen.next();
350        not_none_block.body.push(Instruction::access_chain(
351            flag_pointer_type_id,
352            access_idx,
353            blank_intersection_id,
354            &[idx_id],
355        ));
356        not_none_block
357            .body
358            .push(Instruction::store(access_idx, primitive_index_id, None));
359
360        let idx_id = self.get_index_constant(9);
361        let access_idx = self.id_gen.next();
362        not_none_block.body.push(Instruction::access_chain(
363            transform_pointer_type_id,
364            access_idx,
365            blank_intersection_id,
366            &[idx_id],
367        ));
368        not_none_block
369            .body
370            .push(Instruction::store(access_idx, object_to_world_id, None));
371
372        let idx_id = self.get_index_constant(10);
373        let access_idx = self.id_gen.next();
374        not_none_block.body.push(Instruction::access_chain(
375            transform_pointer_type_id,
376            access_idx,
377            blank_intersection_id,
378            &[idx_id],
379        ));
380        not_none_block
381            .body
382            .push(Instruction::store(access_idx, world_to_object_id, None));
383
384        let tri_comp_id = self.id_gen.next();
385        let tri_id = self.get_constant_scalar(crate::Literal::U32(
386            crate::RayQueryIntersection::Triangle as _,
387        ));
388        not_none_block.body.push(Instruction::binary(
389            spirv::Op::IEqual,
390            self.get_bool_type_id(),
391            tri_comp_id,
392            kind_id,
393            tri_id,
394        ));
395
396        let tri_label_id = self.id_gen.next();
397        let mut tri_block = Block::new(tri_label_id);
398
399        let merge_label_id = self.id_gen.next();
400        let merge_block = Block::new(merge_label_id);
401        // t
402        {
403            let block = if is_committed {
404                &mut not_none_block
405            } else {
406                &mut tri_block
407            };
408            let t_id = self.id_gen.next();
409            block.body.push(Instruction::ray_query_get_intersection(
410                spirv::Op::RayQueryGetIntersectionTKHR,
411                scalar_type_id,
412                t_id,
413                query_id,
414                intersection_id,
415            ));
416            let idx_id = self.get_index_constant(1);
417            let access_idx = self.id_gen.next();
418            block.body.push(Instruction::access_chain(
419                float_pointer_type_id,
420                access_idx,
421                blank_intersection_id,
422                &[idx_id],
423            ));
424            block.body.push(Instruction::store(access_idx, t_id, None));
425        }
426        not_none_block.body.push(Instruction::selection_merge(
427            merge_label_id,
428            spirv::SelectionControl::NONE,
429        ));
430        function.consume(
431            not_none_block,
432            Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
433        );
434
435        let barycentrics_id = self.id_gen.next();
436        tri_block.body.push(Instruction::ray_query_get_intersection(
437            spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
438            barycentrics_type_id,
439            barycentrics_id,
440            query_id,
441            intersection_id,
442        ));
443
444        let front_face_id = self.id_gen.next();
445        tri_block.body.push(Instruction::ray_query_get_intersection(
446            spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
447            bool_type_id,
448            front_face_id,
449            query_id,
450            intersection_id,
451        ));
452
453        let idx_id = self.get_index_constant(7);
454        let access_idx = self.id_gen.next();
455        tri_block.body.push(Instruction::access_chain(
456            barycentrics_pointer_type_id,
457            access_idx,
458            blank_intersection_id,
459            &[idx_id],
460        ));
461        tri_block
462            .body
463            .push(Instruction::store(access_idx, barycentrics_id, None));
464
465        let idx_id = self.get_index_constant(8);
466        let access_idx = self.id_gen.next();
467        tri_block.body.push(Instruction::access_chain(
468            bool_pointer_type_id,
469            access_idx,
470            blank_intersection_id,
471            &[idx_id],
472        ));
473        tri_block
474            .body
475            .push(Instruction::store(access_idx, front_face_id, None));
476        function.consume(tri_block, Instruction::branch(merge_label_id));
477        function.consume(merge_block, Instruction::branch(outer_merge_label_id));
478        function.consume(outer_merge_block, Instruction::branch(final_label_id));
479
480        let loaded_blank_intersection_id = self.id_gen.next();
481        final_block.body.push(Instruction::load(
482            intersection_type_id,
483            loaded_blank_intersection_id,
484            blank_intersection_id,
485            None,
486        ));
487        function.consume(
488            final_block,
489            Instruction::return_value(loaded_blank_intersection_id),
490        );
491
492        function.to_words(&mut self.logical_layout.function_definitions);
493        self.ray_query_functions.insert(
494            LookupRayQueryFunction::GetIntersection {
495                committed: is_committed,
496            },
497            func_id,
498        );
499        func_id
500    }
501
502    fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word {
503        if let Some(&word) = self
504            .ray_query_functions
505            .get(&LookupRayQueryFunction::Initialize)
506        {
507            return word;
508        }
509
510        let ray_query_type_id = self.get_ray_query_pointer_id();
511        let acceleration_structure_type_id =
512            self.get_localtype_id(LocalType::AccelerationStructure);
513        let ray_desc_type_id = self.get_handle_type_id(
514            ir_module
515                .special_types
516                .ray_desc
517                .expect("ray desc should be set if ray queries are being initialized"),
518        );
519
520        let u32_ty = self.get_u32_type_id();
521        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
522
523        let f32_type_id = self.get_f32_type_id();
524        let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
525
526        let (func_id, mut function, arg_ids) = self.write_function_signature(
527            &[
528                ray_query_type_id,
529                acceleration_structure_type_id,
530                ray_desc_type_id,
531                u32_ptr_ty,
532                f32_ptr_ty,
533            ],
534            self.void_type,
535        );
536
537        let query_id = arg_ids[0];
538        let acceleration_structure_id = arg_ids[1];
539        let desc_id = arg_ids[2];
540        let init_tracker_id = arg_ids[3];
541        let t_max_tracker_id = arg_ids[4];
542
543        let label_id = self.id_gen.next();
544        let mut block = Block::new(label_id);
545
546        let super::ExtractedRayDesc {
547            ray_flags_id,
548            cull_mask_id,
549            tmin_id,
550            tmax_id,
551            ray_origin_id,
552            ray_dir_id,
553            valid_id,
554        } = self.write_extract_ray_desc(
555            &mut block,
556            desc_id,
557            self.ray_query_initialization_tracking,
558        );
559
560        block
561            .body
562            .push(Instruction::store(t_max_tracker_id, tmax_id, None));
563
564        let merge_label_id = self.id_gen.next();
565        let merge_block = Block::new(merge_label_id);
566
567        // NOTE: this block will be unreachable if initialization tracking is disabled.
568        let invalid_label_id = self.id_gen.next();
569        let mut invalid_block = Block::new(invalid_label_id);
570
571        let valid_label_id = self.id_gen.next();
572        let mut valid_block = Block::new(valid_label_id);
573
574        match valid_id {
575            Some(all_valid_id) => {
576                block.body.push(Instruction::selection_merge(
577                    merge_label_id,
578                    spirv::SelectionControl::NONE,
579                ));
580                function.consume(
581                    block,
582                    Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
583                );
584            }
585            None => {
586                function.consume(block, Instruction::branch(valid_label_id));
587            }
588        }
589
590        valid_block.body.push(Instruction::ray_query_initialize(
591            query_id,
592            acceleration_structure_id,
593            ray_flags_id,
594            cull_mask_id,
595            ray_origin_id,
596            tmin_id,
597            ray_dir_id,
598            tmax_id,
599        ));
600
601        let const_initialized =
602            self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::INITIALIZED.bits()));
603        valid_block
604            .body
605            .push(Instruction::store(init_tracker_id, const_initialized, None));
606
607        function.consume(valid_block, Instruction::branch(merge_label_id));
608
609        if self
610            .flags
611            .contains(WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
612        {
613            self.write_debug_printf(
614                &mut invalid_block,
615                "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
616                &[
617                    ray_flags_id,
618                    tmin_id,
619                    tmax_id,
620                    ray_origin_id,
621                    ray_dir_id,
622                ],
623            );
624        }
625
626        function.consume(invalid_block, Instruction::branch(merge_label_id));
627
628        function.consume(merge_block, Instruction::return_void());
629
630        function.to_words(&mut self.logical_layout.function_definitions);
631
632        self.ray_query_functions
633            .insert(LookupRayQueryFunction::Initialize, func_id);
634        func_id
635    }
636
637    fn write_ray_query_proceed(&mut self) -> spirv::Word {
638        if let Some(&word) = self
639            .ray_query_functions
640            .get(&LookupRayQueryFunction::Proceed)
641        {
642            return word;
643        }
644
645        let ray_query_type_id = self.get_ray_query_pointer_id();
646
647        let u32_ty = self.get_u32_type_id();
648        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
649
650        let bool_type_id = self.get_bool_type_id();
651        let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
652
653        let (func_id, mut function, arg_ids) =
654            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id);
655
656        let query_id = arg_ids[0];
657        let init_tracker_id = arg_ids[1];
658
659        let block_id = self.id_gen.next();
660        let mut block = Block::new(block_id);
661
662        // TODO: perhaps this could be replaced with an OpPhi?
663        let proceeded_id = self.id_gen.next();
664        let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
665        block.body.push(Instruction::variable(
666            bool_ptr_ty,
667            proceeded_id,
668            spirv::StorageClass::Function,
669            Some(const_false),
670        ));
671
672        let initialized_tracker_id = self.id_gen.next();
673        block.body.push(Instruction::load(
674            u32_ty,
675            initialized_tracker_id,
676            init_tracker_id,
677            None,
678        ));
679
680        let merge_id = self.id_gen.next();
681        let mut merge_block = Block::new(merge_id);
682
683        let valid_block_id = self.id_gen.next();
684        let mut valid_block = Block::new(valid_block_id);
685
686        let instruction = if self.ray_query_initialization_tracking {
687            let is_initialized = write_ray_flags_contains_flags(
688                self,
689                &mut block,
690                initialized_tracker_id,
691                RayQueryPoint::INITIALIZED.bits(),
692            );
693
694            block.body.push(Instruction::selection_merge(
695                merge_id,
696                spirv::SelectionControl::NONE,
697            ));
698
699            Instruction::branch_conditional(is_initialized, valid_block_id, merge_id)
700        } else {
701            Instruction::branch(valid_block_id)
702        };
703
704        function.consume(block, instruction);
705
706        let has_proceeded = self.id_gen.next();
707        valid_block.body.push(Instruction::ray_query_proceed(
708            bool_type_id,
709            has_proceeded,
710            query_id,
711        ));
712
713        valid_block
714            .body
715            .push(Instruction::store(proceeded_id, has_proceeded, None));
716
717        let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
718            (RayQueryPoint::PROCEED | RayQueryPoint::FINISHED_TRAVERSAL).bits(),
719        ));
720        let add_flag_continuing =
721            self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::PROCEED.bits()));
722
723        let add_flags_id = self.id_gen.next();
724        valid_block.body.push(Instruction::select(
725            u32_ty,
726            add_flags_id,
727            has_proceeded,
728            add_flag_continuing,
729            add_flag_finished,
730        ));
731        let final_flags = self.id_gen.next();
732        valid_block.body.push(Instruction::binary(
733            spirv::Op::BitwiseOr,
734            u32_ty,
735            final_flags,
736            initialized_tracker_id,
737            add_flags_id,
738        ));
739        valid_block
740            .body
741            .push(Instruction::store(init_tracker_id, final_flags, None));
742
743        function.consume(valid_block, Instruction::branch(merge_id));
744
745        let loaded_proceeded_id = self.id_gen.next();
746        merge_block.body.push(Instruction::load(
747            bool_type_id,
748            loaded_proceeded_id,
749            proceeded_id,
750            None,
751        ));
752
753        function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));
754
755        function.to_words(&mut self.logical_layout.function_definitions);
756
757        self.ray_query_functions
758            .insert(LookupRayQueryFunction::Proceed, func_id);
759        func_id
760    }
761
762    fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
763        if let Some(&word) = self
764            .ray_query_functions
765            .get(&LookupRayQueryFunction::GenerateIntersection)
766        {
767            return word;
768        }
769
770        let ray_query_type_id = self.get_ray_query_pointer_id();
771
772        let u32_ty = self.get_u32_type_id();
773        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
774
775        let f32_type_id = self.get_f32_type_id();
776        let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
777
778        let bool_type_id = self.get_bool_type_id();
779
780        let (func_id, mut function, arg_ids) = self.write_function_signature(
781            &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id],
782            self.void_type,
783        );
784
785        let query_id = arg_ids[0];
786        let init_tracker_id = arg_ids[1];
787        let depth_id = arg_ids[2];
788        let t_max_tracker_id = arg_ids[3];
789
790        let block_id = self.id_gen.next();
791        let mut block = Block::new(block_id);
792
793        let current_t = self.id_gen.next();
794        block.body.push(Instruction::variable(
795            f32_ptr_type_id,
796            current_t,
797            spirv::StorageClass::Function,
798            None,
799        ));
800
801        let current_t = self.id_gen.next();
802        block.body.push(Instruction::variable(
803            f32_ptr_type_id,
804            current_t,
805            spirv::StorageClass::Function,
806            None,
807        ));
808
809        let valid_id = self.id_gen.next();
810        let mut valid_block = Block::new(valid_id);
811
812        let final_label_id = self.id_gen.next();
813        let final_block = Block::new(final_label_id);
814
815        let instruction = if self.ray_query_initialization_tracking {
816            let initialized_tracker_id = self.id_gen.next();
817            block.body.push(Instruction::load(
818                u32_ty,
819                initialized_tracker_id,
820                init_tracker_id,
821                None,
822            ));
823
824            let proceeded_id = write_ray_flags_contains_flags(
825                self,
826                &mut block,
827                initialized_tracker_id,
828                RayQueryPoint::PROCEED.bits(),
829            );
830            let finished_proceed_id = write_ray_flags_contains_flags(
831                self,
832                &mut block,
833                initialized_tracker_id,
834                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
835            );
836
837            // Can't find anything to suggest double calling this function is invalid.
838
839            let not_finished_id = self.id_gen.next();
840            block.body.push(Instruction::unary(
841                spirv::Op::LogicalNot,
842                bool_type_id,
843                not_finished_id,
844                finished_proceed_id,
845            ));
846
847            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
848
849            block.body.push(Instruction::selection_merge(
850                final_label_id,
851                spirv::SelectionControl::NONE,
852            ));
853
854            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
855        } else {
856            Instruction::branch(valid_id)
857        };
858
859        function.consume(block, instruction);
860
861        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
862            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
863        ));
864        let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32(
865            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
866        ));
867        let raw_kind_id = self.id_gen.next();
868        valid_block
869            .body
870            .push(Instruction::ray_query_get_intersection(
871                spirv::Op::RayQueryGetIntersectionTypeKHR,
872                u32_ty,
873                raw_kind_id,
874                query_id,
875                intersection_id,
876            ));
877
878        let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
879            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
880        ));
881        let intersection_aabb_id = self.id_gen.next();
882        valid_block.body.push(Instruction::binary(
883            spirv::Op::IEqual,
884            bool_type_id,
885            intersection_aabb_id,
886            raw_kind_id,
887            candidate_aabb_id,
888        ));
889
890        // Check that the provided t value is between t min and the current committed
891        // t value, (https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#VUID-RuntimeSpirv-OpRayQueryGenerateIntersectionKHR-06353)
892
893        // Get the tmin
894        let t_min_id = self.id_gen.next();
895        valid_block.body.push(Instruction::ray_query_get_t_min(
896            f32_type_id,
897            t_min_id,
898            query_id,
899        ));
900
901        // Get the current committed t, or tmax if no hit.
902        // Basically emulate HLSL's (easier) version
903        // Pseudo-code:
904        // ````wgsl
905        // // start of function
906        // var current_t:f32;
907        // ...
908        // let committed_type_id = RayQueryGetIntersectionTypeKHR<Committed>(query_id);
909        // if committed_type_id == Committed_None {
910        //     current_t = load(t_max_tracker);
911        // } else {
912        //     current_t = RayQueryGetIntersectionTKHR<Committed>(query_id);
913        // }
914        // ...
915        // ````
916
917        let committed_type_id = self.id_gen.next();
918        valid_block
919            .body
920            .push(Instruction::ray_query_get_intersection(
921                spirv::Op::RayQueryGetIntersectionTypeKHR,
922                u32_ty,
923                committed_type_id,
924                query_id,
925                committed_intersection_id,
926            ));
927
928        let no_committed = self.id_gen.next();
929        valid_block.body.push(Instruction::binary(
930            spirv::Op::IEqual,
931            bool_type_id,
932            no_committed,
933            committed_type_id,
934            self.get_constant_scalar(crate::Literal::U32(
935                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _,
936            )),
937        ));
938
939        let next_valid_block_id = self.id_gen.next();
940        let no_committed_block_id = self.id_gen.next();
941        let mut no_committed_block = Block::new(no_committed_block_id);
942        let committed_block_id = self.id_gen.next();
943        let mut committed_block = Block::new(committed_block_id);
944        valid_block.body.push(Instruction::selection_merge(
945            next_valid_block_id,
946            spirv::SelectionControl::NONE,
947        ));
948        function.consume(
949            valid_block,
950            Instruction::branch_conditional(
951                no_committed,
952                no_committed_block_id,
953                committed_block_id,
954            ),
955        );
956
957        // Assign t_max to current_t
958        let t_max_id = self.id_gen.next();
959        no_committed_block.body.push(Instruction::load(
960            f32_type_id,
961            t_max_id,
962            t_max_tracker_id,
963            None,
964        ));
965        no_committed_block
966            .body
967            .push(Instruction::store(current_t, t_max_id, None));
968        function.consume(no_committed_block, Instruction::branch(next_valid_block_id));
969
970        // Assign t_current to current_t
971        let latest_t_id = self.id_gen.next();
972        committed_block
973            .body
974            .push(Instruction::ray_query_get_intersection(
975                spirv::Op::RayQueryGetIntersectionTKHR,
976                f32_type_id,
977                latest_t_id,
978                query_id,
979                intersection_id,
980            ));
981        committed_block
982            .body
983            .push(Instruction::store(current_t, latest_t_id, None));
984        function.consume(committed_block, Instruction::branch(next_valid_block_id));
985
986        let mut valid_block = Block::new(next_valid_block_id);
987
988        let t_ge_t_min = self.id_gen.next();
989        valid_block.body.push(Instruction::binary(
990            spirv::Op::FOrdGreaterThanEqual,
991            bool_type_id,
992            t_ge_t_min,
993            depth_id,
994            t_min_id,
995        ));
996        let t_current = self.id_gen.next();
997        valid_block
998            .body
999            .push(Instruction::load(f32_type_id, t_current, current_t, None));
1000        let t_le_t_current = self.id_gen.next();
1001        valid_block.body.push(Instruction::binary(
1002            spirv::Op::FOrdLessThanEqual,
1003            bool_type_id,
1004            t_le_t_current,
1005            depth_id,
1006            t_current,
1007        ));
1008
1009        let t_in_range = self.id_gen.next();
1010        valid_block.body.push(Instruction::binary(
1011            spirv::Op::LogicalAnd,
1012            bool_type_id,
1013            t_in_range,
1014            t_ge_t_min,
1015            t_le_t_current,
1016        ));
1017
1018        let call_valid_id = self.id_gen.next();
1019        valid_block.body.push(Instruction::binary(
1020            spirv::Op::LogicalAnd,
1021            bool_type_id,
1022            call_valid_id,
1023            t_in_range,
1024            intersection_aabb_id,
1025        ));
1026
1027        let generate_label_id = self.id_gen.next();
1028        let mut generate_block = Block::new(generate_label_id);
1029
1030        let merge_label_id = self.id_gen.next();
1031        let merge_block = Block::new(merge_label_id);
1032
1033        valid_block.body.push(Instruction::selection_merge(
1034            merge_label_id,
1035            spirv::SelectionControl::NONE,
1036        ));
1037        function.consume(
1038            valid_block,
1039            Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id),
1040        );
1041
1042        generate_block
1043            .body
1044            .push(Instruction::ray_query_generate_intersection(
1045                query_id, depth_id,
1046            ));
1047
1048        function.consume(generate_block, Instruction::branch(merge_label_id));
1049        function.consume(merge_block, Instruction::branch(final_label_id));
1050
1051        function.consume(final_block, Instruction::return_void());
1052
1053        function.to_words(&mut self.logical_layout.function_definitions);
1054
1055        self.ray_query_functions
1056            .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
1057        func_id
1058    }
1059
1060    fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
1061        if let Some(&word) = self
1062            .ray_query_functions
1063            .get(&LookupRayQueryFunction::ConfirmIntersection)
1064        {
1065            return word;
1066        }
1067
1068        let ray_query_type_id = self.get_ray_query_pointer_id();
1069
1070        let u32_ty = self.get_u32_type_id();
1071        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1072
1073        let bool_type_id = self.get_bool_type_id();
1074
1075        let (func_id, mut function, arg_ids) =
1076            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1077
1078        let query_id = arg_ids[0];
1079        let init_tracker_id = arg_ids[1];
1080
1081        let block_id = self.id_gen.next();
1082        let mut block = Block::new(block_id);
1083
1084        let valid_id = self.id_gen.next();
1085        let mut valid_block = Block::new(valid_id);
1086
1087        let final_label_id = self.id_gen.next();
1088        let final_block = Block::new(final_label_id);
1089
1090        let instruction = if self.ray_query_initialization_tracking {
1091            let initialized_tracker_id = self.id_gen.next();
1092            block.body.push(Instruction::load(
1093                u32_ty,
1094                initialized_tracker_id,
1095                init_tracker_id,
1096                None,
1097            ));
1098
1099            let proceeded_id = write_ray_flags_contains_flags(
1100                self,
1101                &mut block,
1102                initialized_tracker_id,
1103                RayQueryPoint::PROCEED.bits(),
1104            );
1105            let finished_proceed_id = write_ray_flags_contains_flags(
1106                self,
1107                &mut block,
1108                initialized_tracker_id,
1109                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1110            );
1111            // Although it seems strange to call this twice, I (Vecvec) can't find anything to suggest double calling this function is invalid.
1112            let not_finished_id = self.id_gen.next();
1113            block.body.push(Instruction::unary(
1114                spirv::Op::LogicalNot,
1115                bool_type_id,
1116                not_finished_id,
1117                finished_proceed_id,
1118            ));
1119
1120            let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1121
1122            block.body.push(Instruction::selection_merge(
1123                final_label_id,
1124                spirv::SelectionControl::NONE,
1125            ));
1126
1127            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1128        } else {
1129            Instruction::branch(valid_id)
1130        };
1131
1132        function.consume(block, instruction);
1133
1134        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1135            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1136        ));
1137        let raw_kind_id = self.id_gen.next();
1138        valid_block
1139            .body
1140            .push(Instruction::ray_query_get_intersection(
1141                spirv::Op::RayQueryGetIntersectionTypeKHR,
1142                u32_ty,
1143                raw_kind_id,
1144                query_id,
1145                intersection_id,
1146            ));
1147
1148        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
1149            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
1150        ));
1151        let intersection_tri_id = self.id_gen.next();
1152        valid_block.body.push(Instruction::binary(
1153            spirv::Op::IEqual,
1154            bool_type_id,
1155            intersection_tri_id,
1156            raw_kind_id,
1157            candidate_tri_id,
1158        ));
1159
1160        let generate_label_id = self.id_gen.next();
1161        let mut generate_block = Block::new(generate_label_id);
1162
1163        let merge_label_id = self.id_gen.next();
1164        let merge_block = Block::new(merge_label_id);
1165
1166        valid_block.body.push(Instruction::selection_merge(
1167            merge_label_id,
1168            spirv::SelectionControl::NONE,
1169        ));
1170        function.consume(
1171            valid_block,
1172            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1173        );
1174
1175        generate_block
1176            .body
1177            .push(Instruction::ray_query_confirm_intersection(query_id));
1178
1179        function.consume(generate_block, Instruction::branch(merge_label_id));
1180        function.consume(merge_block, Instruction::branch(final_label_id));
1181
1182        function.consume(final_block, Instruction::return_void());
1183
1184        self.ray_query_functions
1185            .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);
1186
1187        function.to_words(&mut self.logical_layout.function_definitions);
1188
1189        func_id
1190    }
1191
1192    fn write_ray_query_get_vertex_positions(
1193        &mut self,
1194        is_committed: bool,
1195        ir_module: &crate::Module,
1196    ) -> spirv::Word {
1197        if let Some(&word) =
1198            self.ray_query_functions
1199                .get(&LookupRayQueryFunction::GetVertexPositions {
1200                    committed: is_committed,
1201                })
1202        {
1203            return word;
1204        }
1205
1206        let (committed_ty, committed_tri_ty) = if is_committed {
1207            (
1208                spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
1209                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
1210                    as u32,
1211            )
1212        } else {
1213            (
1214                spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
1215                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
1216                    as u32,
1217            )
1218        };
1219
1220        let ray_query_type_id = self.get_ray_query_pointer_id();
1221
1222        let u32_ty = self.get_u32_type_id();
1223        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1224
1225        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1226            *ir_module
1227                .special_types
1228                .ray_vertex_return
1229                .as_ref()
1230                .expect("must be generated when reading in get vertex position"),
1231        );
1232        let ptr_return_ty =
1233            self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);
1234
1235        let bool_type_id = self.get_bool_type_id();
1236
1237        let (func_id, mut function, arg_ids) = self.write_function_signature(
1238            &[ray_query_type_id, u32_ptr_ty],
1239            rq_get_vertex_positions_ty_id,
1240        );
1241
1242        let query_id = arg_ids[0];
1243        let init_tracker_id = arg_ids[1];
1244
1245        let block_id = self.id_gen.next();
1246        let mut block = Block::new(block_id);
1247
1248        let return_id = self.id_gen.next();
1249        block.body.push(Instruction::variable(
1250            ptr_return_ty,
1251            return_id,
1252            spirv::StorageClass::Function,
1253            Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
1254        ));
1255
1256        let valid_id = self.id_gen.next();
1257        let mut valid_block = Block::new(valid_id);
1258
1259        let final_label_id = self.id_gen.next();
1260        let mut final_block = Block::new(final_label_id);
1261
1262        let instruction = if self.ray_query_initialization_tracking {
1263            let initialized_tracker_id = self.id_gen.next();
1264            block.body.push(Instruction::load(
1265                u32_ty,
1266                initialized_tracker_id,
1267                init_tracker_id,
1268                None,
1269            ));
1270
1271            let proceeded_id = write_ray_flags_contains_flags(
1272                self,
1273                &mut block,
1274                initialized_tracker_id,
1275                RayQueryPoint::PROCEED.bits(),
1276            );
1277            let finished_proceed_id = write_ray_flags_contains_flags(
1278                self,
1279                &mut block,
1280                initialized_tracker_id,
1281                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1282            );
1283
1284            let correct_finish_id = if is_committed {
1285                finished_proceed_id
1286            } else {
1287                let not_finished_id = self.id_gen.next();
1288                block.body.push(Instruction::unary(
1289                    spirv::Op::LogicalNot,
1290                    bool_type_id,
1291                    not_finished_id,
1292                    finished_proceed_id,
1293                ));
1294                not_finished_id
1295            };
1296
1297            let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id);
1298            block.body.push(Instruction::selection_merge(
1299                final_label_id,
1300                spirv::SelectionControl::NONE,
1301            ));
1302            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1303        } else {
1304            Instruction::branch(valid_id)
1305        };
1306
1307        function.consume(block, instruction);
1308
1309        let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
1310        let raw_kind_id = self.id_gen.next();
1311        valid_block
1312            .body
1313            .push(Instruction::ray_query_get_intersection(
1314                spirv::Op::RayQueryGetIntersectionTypeKHR,
1315                u32_ty,
1316                raw_kind_id,
1317                query_id,
1318                intersection_id,
1319            ));
1320
1321        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
1322        let intersection_tri_id = self.id_gen.next();
1323        valid_block.body.push(Instruction::binary(
1324            spirv::Op::IEqual,
1325            bool_type_id,
1326            intersection_tri_id,
1327            raw_kind_id,
1328            candidate_tri_id,
1329        ));
1330
1331        let generate_label_id = self.id_gen.next();
1332        let mut vertex_return_block = Block::new(generate_label_id);
1333
1334        let merge_label_id = self.id_gen.next();
1335        let merge_block = Block::new(merge_label_id);
1336
1337        valid_block.body.push(Instruction::selection_merge(
1338            merge_label_id,
1339            spirv::SelectionControl::NONE,
1340        ));
1341        function.consume(
1342            valid_block,
1343            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1344        );
1345
1346        let vertices_id = self.id_gen.next();
1347        vertex_return_block
1348            .body
1349            .push(Instruction::ray_query_return_vertex_position(
1350                rq_get_vertex_positions_ty_id,
1351                vertices_id,
1352                query_id,
1353                intersection_id,
1354            ));
1355        vertex_return_block
1356            .body
1357            .push(Instruction::store(return_id, vertices_id, None));
1358
1359        function.consume(vertex_return_block, Instruction::branch(merge_label_id));
1360        function.consume(merge_block, Instruction::branch(final_label_id));
1361
1362        let loaded_pos_id = self.id_gen.next();
1363        final_block.body.push(Instruction::load(
1364            rq_get_vertex_positions_ty_id,
1365            loaded_pos_id,
1366            return_id,
1367            None,
1368        ));
1369
1370        function.consume(final_block, Instruction::return_value(loaded_pos_id));
1371
1372        self.ray_query_functions.insert(
1373            LookupRayQueryFunction::GetVertexPositions {
1374                committed: is_committed,
1375            },
1376            func_id,
1377        );
1378
1379        function.to_words(&mut self.logical_layout.function_definitions);
1380
1381        func_id
1382    }
1383
1384    fn write_ray_query_terminate(&mut self) -> spirv::Word {
1385        if let Some(&word) = self
1386            .ray_query_functions
1387            .get(&LookupRayQueryFunction::Terminate)
1388        {
1389            return word;
1390        }
1391
1392        let ray_query_type_id = self.get_ray_query_pointer_id();
1393
1394        let u32_ty = self.get_u32_type_id();
1395        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1396
1397        let bool_type_id = self.get_bool_type_id();
1398
1399        let (func_id, mut function, arg_ids) =
1400            self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1401
1402        let query_id = arg_ids[0];
1403        let init_tracker_id = arg_ids[1];
1404
1405        let block_id = self.id_gen.next();
1406        let mut block = Block::new(block_id);
1407
1408        let initialized_tracker_id = self.id_gen.next();
1409        block.body.push(Instruction::load(
1410            u32_ty,
1411            initialized_tracker_id,
1412            init_tracker_id,
1413            None,
1414        ));
1415
1416        let merge_id = self.id_gen.next();
1417        let merge_block = Block::new(merge_id);
1418
1419        let valid_block_id = self.id_gen.next();
1420        let mut valid_block = Block::new(valid_block_id);
1421
1422        let instruction = if self.ray_query_initialization_tracking {
1423            let has_proceeded = write_ray_flags_contains_flags(
1424                self,
1425                &mut block,
1426                initialized_tracker_id,
1427                RayQueryPoint::PROCEED.bits(),
1428            );
1429
1430            let finished_proceed_id = write_ray_flags_contains_flags(
1431                self,
1432                &mut block,
1433                initialized_tracker_id,
1434                RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1435            );
1436
1437            let not_finished_id = self.id_gen.next();
1438            block.body.push(Instruction::unary(
1439                spirv::Op::LogicalNot,
1440                bool_type_id,
1441                not_finished_id,
1442                finished_proceed_id,
1443            ));
1444
1445            let valid_call = self.write_logical_and(&mut block, not_finished_id, has_proceeded);
1446
1447            block.body.push(Instruction::selection_merge(
1448                merge_id,
1449                spirv::SelectionControl::NONE,
1450            ));
1451
1452            Instruction::branch_conditional(valid_call, valid_block_id, merge_id)
1453        } else {
1454            Instruction::branch(valid_block_id)
1455        };
1456
1457        function.consume(block, instruction);
1458
1459        valid_block
1460            .body
1461            .push(Instruction::ray_query_terminate(query_id));
1462
1463        function.consume(valid_block, Instruction::branch(merge_id));
1464
1465        function.consume(merge_block, Instruction::return_void());
1466
1467        function.to_words(&mut self.logical_layout.function_definitions);
1468
1469        self.ray_query_functions
1470            .insert(LookupRayQueryFunction::Proceed, func_id);
1471        func_id
1472    }
1473}
1474
1475impl BlockContext<'_> {
1476    pub(in super::super) fn write_ray_query_function(
1477        &mut self,
1478        query: Handle<crate::Expression>,
1479        function: &crate::RayQueryFunction,
1480        block: &mut Block,
1481    ) {
1482        let query_id = self.cached[query];
1483        let tracker_ids = *self
1484            .ray_query_tracker_expr
1485            .get(&query)
1486            .expect("not a cached ray query");
1487
1488        match *function {
1489            crate::RayQueryFunction::Initialize {
1490                acceleration_structure,
1491                descriptor,
1492            } => {
1493                let desc_id = self.cached[descriptor];
1494                let acc_struct_id = self.get_handle_id(acceleration_structure);
1495
1496                let func = self.writer.write_ray_query_initialize(self.ir_module);
1497
1498                let func_id = self.gen_id();
1499                block.body.push(Instruction::function_call(
1500                    self.writer.void_type,
1501                    func_id,
1502                    func,
1503                    &[
1504                        query_id,
1505                        acc_struct_id,
1506                        desc_id,
1507                        tracker_ids.initialized_tracker,
1508                        tracker_ids.t_max_tracker,
1509                    ],
1510                ));
1511            }
1512            crate::RayQueryFunction::Proceed { result } => {
1513                let id = self.gen_id();
1514                self.cached[result] = id;
1515
1516                let bool_ty = self.writer.get_bool_type_id();
1517
1518                let func_id = self.writer.write_ray_query_proceed();
1519                block.body.push(Instruction::function_call(
1520                    bool_ty,
1521                    id,
1522                    func_id,
1523                    &[query_id, tracker_ids.initialized_tracker],
1524                ));
1525            }
1526            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1527                let hit_id = self.cached[hit_t];
1528
1529                let func_id = self.writer.write_ray_query_generate_intersection();
1530
1531                let func_call_id = self.gen_id();
1532                block.body.push(Instruction::function_call(
1533                    self.writer.void_type,
1534                    func_call_id,
1535                    func_id,
1536                    &[
1537                        query_id,
1538                        tracker_ids.initialized_tracker,
1539                        hit_id,
1540                        tracker_ids.t_max_tracker,
1541                    ],
1542                ));
1543            }
1544            crate::RayQueryFunction::ConfirmIntersection => {
1545                let func_id = self.writer.write_ray_query_confirm_intersection();
1546
1547                let func_call_id = self.gen_id();
1548                block.body.push(Instruction::function_call(
1549                    self.writer.void_type,
1550                    func_call_id,
1551                    func_id,
1552                    &[query_id, tracker_ids.initialized_tracker],
1553                ));
1554            }
1555            crate::RayQueryFunction::Terminate => {
1556                let id = self.gen_id();
1557
1558                let func_id = self.writer.write_ray_query_terminate();
1559                block.body.push(Instruction::function_call(
1560                    self.writer.void_type,
1561                    id,
1562                    func_id,
1563                    &[query_id, tracker_ids.initialized_tracker],
1564                ));
1565            }
1566        }
1567    }
1568
1569    pub(in super::super) fn write_ray_query_return_vertex_position(
1570        &mut self,
1571        query: Handle<crate::Expression>,
1572        block: &mut Block,
1573        is_committed: bool,
1574    ) -> spirv::Word {
1575        let fn_id = self
1576            .writer
1577            .write_ray_query_get_vertex_positions(is_committed, self.ir_module);
1578
1579        let query_id = self.cached[query];
1580        let tracker_id = *self
1581            .ray_query_tracker_expr
1582            .get(&query)
1583            .expect("not a cached ray query");
1584
1585        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1586            *self
1587                .ir_module
1588                .special_types
1589                .ray_vertex_return
1590                .as_ref()
1591                .expect("must be generated when reading in get vertex position"),
1592        );
1593
1594        let func_call_id = self.gen_id();
1595        block.body.push(Instruction::function_call(
1596            rq_get_vertex_positions_ty_id,
1597            func_call_id,
1598            fn_id,
1599            &[query_id, tracker_id.initialized_tracker],
1600        ));
1601        func_call_id
1602    }
1603}