naga/back/spv/
ray.rs

1/*!
2Generating SPIR-V for ray query operations.
3*/
4
5use alloc::vec;
6
7use super::{
8    Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType,
9    Writer,
10};
11use crate::arena::Handle;
12
13impl Writer {
14    pub(super) fn write_ray_query_get_intersection_function(
15        &mut self,
16        is_committed: bool,
17        ir_module: &crate::Module,
18    ) -> spirv::Word {
19        if is_committed {
20            if let Some(func_id) = self.ray_get_committed_intersection_function {
21                return func_id;
22            }
23        } else if let Some(func_id) = self.ray_get_candidate_intersection_function {
24            return func_id;
25        };
26        let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
27        let intersection_type_id = self.get_handle_type_id(ray_intersection);
28        let intersection_pointer_type_id =
29            self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
30
31        let flag_type_id = self.get_u32_type_id();
32        let flag_pointer_type_id =
33            self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
34
35        let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
36            columns: crate::VectorSize::Quad,
37            rows: crate::VectorSize::Tri,
38            scalar: crate::Scalar::F32,
39        });
40        let transform_pointer_type_id =
41            self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
42
43        let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
44            size: crate::VectorSize::Bi,
45            scalar: crate::Scalar::F32,
46        });
47        let barycentrics_pointer_type_id =
48            self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
49
50        let bool_type_id = self.get_bool_type_id();
51        let bool_pointer_type_id =
52            self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
53
54        let scalar_type_id = self.get_f32_type_id();
55        let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
56
57        let argument_type_id = self.get_ray_query_pointer_id();
58
59        let func_ty = self.get_function_type(LookupFunctionType {
60            parameter_type_ids: vec![argument_type_id],
61            return_type_id: intersection_type_id,
62        });
63
64        let mut function = Function::default();
65        let func_id = self.id_gen.next();
66        function.signature = Some(Instruction::function(
67            intersection_type_id,
68            func_id,
69            spirv::FunctionControl::empty(),
70            func_ty,
71        ));
72        let blank_intersection = self.get_constant_null(intersection_type_id);
73        let query_id = self.id_gen.next();
74        let instruction = Instruction::function_parameter(argument_type_id, query_id);
75        function.parameters.push(FunctionArgument {
76            instruction,
77            handle_id: 0,
78        });
79
80        let label_id = self.id_gen.next();
81        let mut block = Block::new(label_id);
82
83        let blank_intersection_id = self.id_gen.next();
84        block.body.push(Instruction::variable(
85            intersection_pointer_type_id,
86            blank_intersection_id,
87            spirv::StorageClass::Function,
88            Some(blank_intersection),
89        ));
90
91        let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
92            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
93        } else {
94            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
95        } as _));
96        let raw_kind_id = self.id_gen.next();
97        block.body.push(Instruction::ray_query_get_intersection(
98            spirv::Op::RayQueryGetIntersectionTypeKHR,
99            flag_type_id,
100            raw_kind_id,
101            query_id,
102            intersection_id,
103        ));
104        let kind_id = if is_committed {
105            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
106            raw_kind_id
107        } else {
108            // Remap from the candidate kind to IR
109            let condition_id = self.id_gen.next();
110            let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
111                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
112                    as _,
113            ));
114            block.body.push(Instruction::binary(
115                spirv::Op::IEqual,
116                self.get_bool_type_id(),
117                condition_id,
118                raw_kind_id,
119                committed_triangle_kind_id,
120            ));
121            let kind_id = self.id_gen.next();
122            block.body.push(Instruction::select(
123                flag_type_id,
124                kind_id,
125                condition_id,
126                self.get_constant_scalar(crate::Literal::U32(
127                    crate::RayQueryIntersection::Triangle as _,
128                )),
129                self.get_constant_scalar(crate::Literal::U32(
130                    crate::RayQueryIntersection::Aabb as _,
131                )),
132            ));
133            kind_id
134        };
135        let idx_id = self.get_index_constant(0);
136        let access_idx = self.id_gen.next();
137        block.body.push(Instruction::access_chain(
138            flag_pointer_type_id,
139            access_idx,
140            blank_intersection_id,
141            &[idx_id],
142        ));
143        block
144            .body
145            .push(Instruction::store(access_idx, kind_id, None));
146
147        let not_none_comp_id = self.id_gen.next();
148        let none_id =
149            self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
150        block.body.push(Instruction::binary(
151            spirv::Op::INotEqual,
152            self.get_bool_type_id(),
153            not_none_comp_id,
154            kind_id,
155            none_id,
156        ));
157
158        let not_none_label_id = self.id_gen.next();
159        let mut not_none_block = Block::new(not_none_label_id);
160
161        let final_label_id = self.id_gen.next();
162        let mut final_block = Block::new(final_label_id);
163
164        block.body.push(Instruction::selection_merge(
165            final_label_id,
166            spirv::SelectionControl::NONE,
167        ));
168        function.consume(
169            block,
170            Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id),
171        );
172
173        let instance_custom_index_id = self.id_gen.next();
174        not_none_block
175            .body
176            .push(Instruction::ray_query_get_intersection(
177                spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
178                flag_type_id,
179                instance_custom_index_id,
180                query_id,
181                intersection_id,
182            ));
183        let instance_id = self.id_gen.next();
184        not_none_block
185            .body
186            .push(Instruction::ray_query_get_intersection(
187                spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
188                flag_type_id,
189                instance_id,
190                query_id,
191                intersection_id,
192            ));
193        let sbt_record_offset_id = self.id_gen.next();
194        not_none_block
195            .body
196            .push(Instruction::ray_query_get_intersection(
197                spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
198                flag_type_id,
199                sbt_record_offset_id,
200                query_id,
201                intersection_id,
202            ));
203        let geometry_index_id = self.id_gen.next();
204        not_none_block
205            .body
206            .push(Instruction::ray_query_get_intersection(
207                spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
208                flag_type_id,
209                geometry_index_id,
210                query_id,
211                intersection_id,
212            ));
213        let primitive_index_id = self.id_gen.next();
214        not_none_block
215            .body
216            .push(Instruction::ray_query_get_intersection(
217                spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
218                flag_type_id,
219                primitive_index_id,
220                query_id,
221                intersection_id,
222            ));
223
224        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
225        // but it's not a property of an intersection.
226
227        let object_to_world_id = self.id_gen.next();
228        not_none_block
229            .body
230            .push(Instruction::ray_query_get_intersection(
231                spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
232                transform_type_id,
233                object_to_world_id,
234                query_id,
235                intersection_id,
236            ));
237        let world_to_object_id = self.id_gen.next();
238        not_none_block
239            .body
240            .push(Instruction::ray_query_get_intersection(
241                spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
242                transform_type_id,
243                world_to_object_id,
244                query_id,
245                intersection_id,
246            ));
247
248        // instance custom index
249        let idx_id = self.get_index_constant(2);
250        let access_idx = self.id_gen.next();
251        not_none_block.body.push(Instruction::access_chain(
252            flag_pointer_type_id,
253            access_idx,
254            blank_intersection_id,
255            &[idx_id],
256        ));
257        not_none_block.body.push(Instruction::store(
258            access_idx,
259            instance_custom_index_id,
260            None,
261        ));
262
263        // instance
264        let idx_id = self.get_index_constant(3);
265        let access_idx = self.id_gen.next();
266        not_none_block.body.push(Instruction::access_chain(
267            flag_pointer_type_id,
268            access_idx,
269            blank_intersection_id,
270            &[idx_id],
271        ));
272        not_none_block
273            .body
274            .push(Instruction::store(access_idx, instance_id, None));
275
276        let idx_id = self.get_index_constant(4);
277        let access_idx = self.id_gen.next();
278        not_none_block.body.push(Instruction::access_chain(
279            flag_pointer_type_id,
280            access_idx,
281            blank_intersection_id,
282            &[idx_id],
283        ));
284        not_none_block
285            .body
286            .push(Instruction::store(access_idx, sbt_record_offset_id, None));
287
288        let idx_id = self.get_index_constant(5);
289        let access_idx = self.id_gen.next();
290        not_none_block.body.push(Instruction::access_chain(
291            flag_pointer_type_id,
292            access_idx,
293            blank_intersection_id,
294            &[idx_id],
295        ));
296        not_none_block
297            .body
298            .push(Instruction::store(access_idx, geometry_index_id, None));
299
300        let idx_id = self.get_index_constant(6);
301        let access_idx = self.id_gen.next();
302        not_none_block.body.push(Instruction::access_chain(
303            flag_pointer_type_id,
304            access_idx,
305            blank_intersection_id,
306            &[idx_id],
307        ));
308        not_none_block
309            .body
310            .push(Instruction::store(access_idx, primitive_index_id, None));
311
312        let idx_id = self.get_index_constant(9);
313        let access_idx = self.id_gen.next();
314        not_none_block.body.push(Instruction::access_chain(
315            transform_pointer_type_id,
316            access_idx,
317            blank_intersection_id,
318            &[idx_id],
319        ));
320        not_none_block
321            .body
322            .push(Instruction::store(access_idx, object_to_world_id, None));
323
324        let idx_id = self.get_index_constant(10);
325        let access_idx = self.id_gen.next();
326        not_none_block.body.push(Instruction::access_chain(
327            transform_pointer_type_id,
328            access_idx,
329            blank_intersection_id,
330            &[idx_id],
331        ));
332        not_none_block
333            .body
334            .push(Instruction::store(access_idx, world_to_object_id, None));
335
336        let tri_comp_id = self.id_gen.next();
337        let tri_id = self.get_constant_scalar(crate::Literal::U32(
338            crate::RayQueryIntersection::Triangle as _,
339        ));
340        not_none_block.body.push(Instruction::binary(
341            spirv::Op::IEqual,
342            self.get_bool_type_id(),
343            tri_comp_id,
344            kind_id,
345            tri_id,
346        ));
347
348        let tri_label_id = self.id_gen.next();
349        let mut tri_block = Block::new(tri_label_id);
350
351        let merge_label_id = self.id_gen.next();
352        let merge_block = Block::new(merge_label_id);
353        // t
354        {
355            let block = if is_committed {
356                &mut not_none_block
357            } else {
358                &mut tri_block
359            };
360            let t_id = self.id_gen.next();
361            block.body.push(Instruction::ray_query_get_intersection(
362                spirv::Op::RayQueryGetIntersectionTKHR,
363                scalar_type_id,
364                t_id,
365                query_id,
366                intersection_id,
367            ));
368            let idx_id = self.get_index_constant(1);
369            let access_idx = self.id_gen.next();
370            block.body.push(Instruction::access_chain(
371                float_pointer_type_id,
372                access_idx,
373                blank_intersection_id,
374                &[idx_id],
375            ));
376            block.body.push(Instruction::store(access_idx, t_id, None));
377        }
378        not_none_block.body.push(Instruction::selection_merge(
379            merge_label_id,
380            spirv::SelectionControl::NONE,
381        ));
382        function.consume(
383            not_none_block,
384            Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
385        );
386
387        let barycentrics_id = self.id_gen.next();
388        tri_block.body.push(Instruction::ray_query_get_intersection(
389            spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
390            barycentrics_type_id,
391            barycentrics_id,
392            query_id,
393            intersection_id,
394        ));
395
396        let front_face_id = self.id_gen.next();
397        tri_block.body.push(Instruction::ray_query_get_intersection(
398            spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
399            bool_type_id,
400            front_face_id,
401            query_id,
402            intersection_id,
403        ));
404
405        let idx_id = self.get_index_constant(7);
406        let access_idx = self.id_gen.next();
407        tri_block.body.push(Instruction::access_chain(
408            barycentrics_pointer_type_id,
409            access_idx,
410            blank_intersection_id,
411            &[idx_id],
412        ));
413        tri_block
414            .body
415            .push(Instruction::store(access_idx, barycentrics_id, None));
416
417        let idx_id = self.get_index_constant(8);
418        let access_idx = self.id_gen.next();
419        tri_block.body.push(Instruction::access_chain(
420            bool_pointer_type_id,
421            access_idx,
422            blank_intersection_id,
423            &[idx_id],
424        ));
425        tri_block
426            .body
427            .push(Instruction::store(access_idx, front_face_id, None));
428        function.consume(tri_block, Instruction::branch(merge_label_id));
429        function.consume(merge_block, Instruction::branch(final_label_id));
430
431        let loaded_blank_intersection_id = self.id_gen.next();
432        final_block.body.push(Instruction::load(
433            intersection_type_id,
434            loaded_blank_intersection_id,
435            blank_intersection_id,
436            None,
437        ));
438        function.consume(
439            final_block,
440            Instruction::return_value(loaded_blank_intersection_id),
441        );
442
443        function.to_words(&mut self.logical_layout.function_definitions);
444        if is_committed {
445            self.ray_get_committed_intersection_function = Some(func_id);
446        } else {
447            self.ray_get_candidate_intersection_function = Some(func_id);
448        }
449        func_id
450    }
451}
452
453impl BlockContext<'_> {
454    pub(super) fn write_ray_query_function(
455        &mut self,
456        query: Handle<crate::Expression>,
457        function: &crate::RayQueryFunction,
458        block: &mut Block,
459    ) {
460        let query_id = self.cached[query];
461        match *function {
462            crate::RayQueryFunction::Initialize {
463                acceleration_structure,
464                descriptor,
465            } => {
466                //Note: composite extract indices and types must match `generate_ray_desc_type`
467                let desc_id = self.cached[descriptor];
468                let acc_struct_id = self.get_handle_id(acceleration_structure);
469
470                let flag_type_id =
471                    self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
472                let ray_flags_id = self.gen_id();
473                block.body.push(Instruction::composite_extract(
474                    flag_type_id,
475                    ray_flags_id,
476                    desc_id,
477                    &[0],
478                ));
479                let cull_mask_id = self.gen_id();
480                block.body.push(Instruction::composite_extract(
481                    flag_type_id,
482                    cull_mask_id,
483                    desc_id,
484                    &[1],
485                ));
486
487                let scalar_type_id =
488                    self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32));
489                let tmin_id = self.gen_id();
490                block.body.push(Instruction::composite_extract(
491                    scalar_type_id,
492                    tmin_id,
493                    desc_id,
494                    &[2],
495                ));
496                let tmax_id = self.gen_id();
497                block.body.push(Instruction::composite_extract(
498                    scalar_type_id,
499                    tmax_id,
500                    desc_id,
501                    &[3],
502                ));
503
504                let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
505                    size: crate::VectorSize::Tri,
506                    scalar: crate::Scalar::F32,
507                });
508                let ray_origin_id = self.gen_id();
509                block.body.push(Instruction::composite_extract(
510                    vector_type_id,
511                    ray_origin_id,
512                    desc_id,
513                    &[4],
514                ));
515                let ray_dir_id = self.gen_id();
516                block.body.push(Instruction::composite_extract(
517                    vector_type_id,
518                    ray_dir_id,
519                    desc_id,
520                    &[5],
521                ));
522
523                block.body.push(Instruction::ray_query_initialize(
524                    query_id,
525                    acc_struct_id,
526                    ray_flags_id,
527                    cull_mask_id,
528                    ray_origin_id,
529                    tmin_id,
530                    ray_dir_id,
531                    tmax_id,
532                ));
533            }
534            crate::RayQueryFunction::Proceed { result } => {
535                let id = self.gen_id();
536                self.cached[result] = id;
537                let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
538
539                block
540                    .body
541                    .push(Instruction::ray_query_proceed(result_type_id, id, query_id));
542            }
543            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
544                let hit_id = self.cached[hit_t];
545                block
546                    .body
547                    .push(Instruction::ray_query_generate_intersection(
548                        query_id, hit_id,
549                    ));
550            }
551            crate::RayQueryFunction::ConfirmIntersection => {
552                block
553                    .body
554                    .push(Instruction::ray_query_confirm_intersection(query_id));
555            }
556            crate::RayQueryFunction::Terminate => {}
557        }
558    }
559
560    pub(super) fn write_ray_query_return_vertex_position(
561        &mut self,
562        query: Handle<crate::Expression>,
563        block: &mut Block,
564        is_committed: bool,
565    ) -> spirv::Word {
566        let query_id = self.cached[query];
567        let id = self.gen_id();
568        let ray_vertex_return_ty = self
569            .ir_module
570            .special_types
571            .ray_vertex_return
572            .expect("type should have been populated");
573        let ray_vertex_return_ty_id = self.writer.get_handle_type_id(ray_vertex_return_ty);
574        let intersection_id =
575            self.writer
576                .get_constant_scalar(crate::Literal::U32(if is_committed {
577                    spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
578                } else {
579                    spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
580                } as _));
581        block
582            .body
583            .push(Instruction::ray_query_return_vertex_position(
584                ray_vertex_return_ty_id,
585                id,
586                query_id,
587                intersection_id,
588            ));
589        id
590    }
591}