1use alloc::{vec, vec::Vec};
6
7use super::{
8 Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType,
9 Writer,
10};
11use crate::{arena::Handle, back::spv::LookupRayQueryFunction};
12
13fn write_ray_flags_contains_flags(
15 writer: &mut Writer,
16 block: &mut Block,
17 id: spirv::Word,
18 flag: u32,
19) -> spirv::Word {
20 let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag));
21 let zero_id = writer.get_constant_scalar(crate::Literal::U32(0));
22 let u32_type_id = writer.get_u32_type_id();
23 let bool_ty = writer.get_bool_type_id();
24
25 let and_id = writer.id_gen.next();
26 block.body.push(Instruction::binary(
27 spirv::Op::BitwiseAnd,
28 u32_type_id,
29 and_id,
30 id,
31 bit_id,
32 ));
33
34 let eq_id = writer.id_gen.next();
35 block.body.push(Instruction::binary(
36 spirv::Op::INotEqual,
37 bool_ty,
38 eq_id,
39 and_id,
40 zero_id,
41 ));
42
43 eq_id
44}
45
46impl Writer {
47 fn write_logical_and(
49 &mut self,
50 block: &mut Block,
51 one: spirv::Word,
52 two: spirv::Word,
53 ) -> spirv::Word {
54 let id = self.id_gen.next();
55 let bool_id = self.get_bool_type_id();
56 block.body.push(Instruction::binary(
57 spirv::Op::LogicalAnd,
58 bool_id,
59 id,
60 one,
61 two,
62 ));
63 id
64 }
65
66 fn write_reduce_and(&mut self, block: &mut Block, mut bools: Vec<spirv::Word>) -> spirv::Word {
67 let mut current_combined = bools.pop().unwrap();
69 for boolean in bools {
70 current_combined = self.write_logical_and(block, current_combined, boolean)
71 }
72 current_combined
73 }
74
75 fn write_function_signature(
77 &mut self,
78 arg_types: &[spirv::Word],
79 return_ty: spirv::Word,
80 ) -> (spirv::Word, Function, Vec<spirv::Word>) {
81 let func_ty = self.get_function_type(LookupFunctionType {
82 parameter_type_ids: Vec::from(arg_types),
83 return_type_id: return_ty,
84 });
85
86 let mut function = Function::default();
87 let func_id = self.id_gen.next();
88 function.signature = Some(Instruction::function(
89 return_ty,
90 func_id,
91 spirv::FunctionControl::empty(),
92 func_ty,
93 ));
94
95 let mut arg_ids = Vec::with_capacity(arg_types.len());
96
97 for (idx, &arg_ty) in arg_types.iter().enumerate() {
98 let id = self.id_gen.next();
99 let instruction = Instruction::function_parameter(arg_ty, id);
100 function.parameters.push(FunctionArgument {
101 instruction,
102 handle_id: idx as u32,
103 });
104 arg_ids.push(id);
105 }
106 (func_id, function, arg_ids)
107 }
108
109 pub(super) fn write_ray_query_get_intersection_function(
110 &mut self,
111 is_committed: bool,
112 ir_module: &crate::Module,
113 ) -> spirv::Word {
114 if let Some(&word) =
115 self.ray_query_functions
116 .get(&LookupRayQueryFunction::GetIntersection {
117 committed: is_committed,
118 })
119 {
120 return word;
121 }
122 let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
123 let intersection_type_id = self.get_handle_type_id(ray_intersection);
124 let intersection_pointer_type_id =
125 self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
126
127 let flag_type_id = self.get_u32_type_id();
128 let flag_pointer_type_id =
129 self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
130
131 let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
132 columns: crate::VectorSize::Quad,
133 rows: crate::VectorSize::Tri,
134 scalar: crate::Scalar::F32,
135 });
136 let transform_pointer_type_id =
137 self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
138
139 let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
140 size: crate::VectorSize::Bi,
141 scalar: crate::Scalar::F32,
142 });
143 let barycentrics_pointer_type_id =
144 self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
145
146 let bool_type_id = self.get_bool_type_id();
147 let bool_pointer_type_id =
148 self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
149
150 let scalar_type_id = self.get_f32_type_id();
151 let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
152
153 let argument_type_id = self.get_ray_query_pointer_id();
154
155 let (func_id, mut function, arg_ids) = self.write_function_signature(
156 &[argument_type_id, flag_pointer_type_id],
157 intersection_type_id,
158 );
159
160 let query_id = arg_ids[0];
161 let intersection_tracker_id = arg_ids[1];
162
163 let label_id = self.id_gen.next();
164 let mut block = Block::new(label_id);
165
166 let blank_intersection = self.get_constant_null(intersection_type_id);
167 let blank_intersection_id = self.id_gen.next();
168 block.body.push(Instruction::variable(
170 intersection_pointer_type_id,
171 blank_intersection_id,
172 spirv::StorageClass::Function,
173 Some(blank_intersection),
174 ));
175
176 let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
177 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
178 } else {
179 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
180 } as _));
181
182 let loaded_ray_query_tracker_id = self.id_gen.next();
183 block.body.push(Instruction::load(
184 flag_type_id,
185 loaded_ray_query_tracker_id,
186 intersection_tracker_id,
187 None,
188 ));
189 let proceeded_id = write_ray_flags_contains_flags(
190 self,
191 &mut block,
192 loaded_ray_query_tracker_id,
193 super::RayQueryPoint::PROCEED.bits(),
194 );
195 let finished_proceed_id = write_ray_flags_contains_flags(
196 self,
197 &mut block,
198 loaded_ray_query_tracker_id,
199 super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
200 );
201 let proceed_finished_correct_id = if is_committed {
202 finished_proceed_id
203 } else {
204 let not_finished_id = self.id_gen.next();
205 block.body.push(Instruction::unary(
206 spirv::Op::LogicalNot,
207 bool_type_id,
208 not_finished_id,
209 finished_proceed_id,
210 ));
211 not_finished_id
212 };
213
214 let is_valid_id =
215 self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id);
216
217 let valid_id = self.id_gen.next();
218 let mut valid_block = Block::new(valid_id);
219
220 let final_label_id = self.id_gen.next();
221 let mut final_block = Block::new(final_label_id);
222
223 block.body.push(Instruction::selection_merge(
224 final_label_id,
225 spirv::SelectionControl::NONE,
226 ));
227 function.consume(
228 block,
229 Instruction::branch_conditional(is_valid_id, valid_id, final_label_id),
230 );
231
232 let raw_kind_id = self.id_gen.next();
233 valid_block
234 .body
235 .push(Instruction::ray_query_get_intersection(
236 spirv::Op::RayQueryGetIntersectionTypeKHR,
237 flag_type_id,
238 raw_kind_id,
239 query_id,
240 intersection_id,
241 ));
242 let kind_id = if is_committed {
243 raw_kind_id
245 } else {
246 let condition_id = self.id_gen.next();
248 let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
249 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
250 as _,
251 ));
252 valid_block.body.push(Instruction::binary(
253 spirv::Op::IEqual,
254 self.get_bool_type_id(),
255 condition_id,
256 raw_kind_id,
257 committed_triangle_kind_id,
258 ));
259 let kind_id = self.id_gen.next();
260 valid_block.body.push(Instruction::select(
261 flag_type_id,
262 kind_id,
263 condition_id,
264 self.get_constant_scalar(crate::Literal::U32(
265 crate::RayQueryIntersection::Triangle as _,
266 )),
267 self.get_constant_scalar(crate::Literal::U32(
268 crate::RayQueryIntersection::Aabb as _,
269 )),
270 ));
271 kind_id
272 };
273 let idx_id = self.get_index_constant(0);
274 let access_idx = self.id_gen.next();
275 valid_block.body.push(Instruction::access_chain(
276 flag_pointer_type_id,
277 access_idx,
278 blank_intersection_id,
279 &[idx_id],
280 ));
281 valid_block
282 .body
283 .push(Instruction::store(access_idx, kind_id, None));
284
285 let not_none_comp_id = self.id_gen.next();
286 let none_id =
287 self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
288 valid_block.body.push(Instruction::binary(
289 spirv::Op::INotEqual,
290 self.get_bool_type_id(),
291 not_none_comp_id,
292 kind_id,
293 none_id,
294 ));
295
296 let not_none_label_id = self.id_gen.next();
297 let mut not_none_block = Block::new(not_none_label_id);
298
299 let outer_merge_label_id = self.id_gen.next();
300 let outer_merge_block = Block::new(outer_merge_label_id);
301
302 valid_block.body.push(Instruction::selection_merge(
303 outer_merge_label_id,
304 spirv::SelectionControl::NONE,
305 ));
306 function.consume(
307 valid_block,
308 Instruction::branch_conditional(
309 not_none_comp_id,
310 not_none_label_id,
311 outer_merge_label_id,
312 ),
313 );
314
315 let instance_custom_index_id = self.id_gen.next();
316 not_none_block
317 .body
318 .push(Instruction::ray_query_get_intersection(
319 spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
320 flag_type_id,
321 instance_custom_index_id,
322 query_id,
323 intersection_id,
324 ));
325 let instance_id = self.id_gen.next();
326 not_none_block
327 .body
328 .push(Instruction::ray_query_get_intersection(
329 spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
330 flag_type_id,
331 instance_id,
332 query_id,
333 intersection_id,
334 ));
335 let sbt_record_offset_id = self.id_gen.next();
336 not_none_block
337 .body
338 .push(Instruction::ray_query_get_intersection(
339 spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
340 flag_type_id,
341 sbt_record_offset_id,
342 query_id,
343 intersection_id,
344 ));
345 let geometry_index_id = self.id_gen.next();
346 not_none_block
347 .body
348 .push(Instruction::ray_query_get_intersection(
349 spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
350 flag_type_id,
351 geometry_index_id,
352 query_id,
353 intersection_id,
354 ));
355 let primitive_index_id = self.id_gen.next();
356 not_none_block
357 .body
358 .push(Instruction::ray_query_get_intersection(
359 spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
360 flag_type_id,
361 primitive_index_id,
362 query_id,
363 intersection_id,
364 ));
365
366 let object_to_world_id = self.id_gen.next();
370 not_none_block
371 .body
372 .push(Instruction::ray_query_get_intersection(
373 spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
374 transform_type_id,
375 object_to_world_id,
376 query_id,
377 intersection_id,
378 ));
379 let world_to_object_id = self.id_gen.next();
380 not_none_block
381 .body
382 .push(Instruction::ray_query_get_intersection(
383 spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
384 transform_type_id,
385 world_to_object_id,
386 query_id,
387 intersection_id,
388 ));
389
390 let idx_id = self.get_index_constant(2);
392 let access_idx = self.id_gen.next();
393 not_none_block.body.push(Instruction::access_chain(
394 flag_pointer_type_id,
395 access_idx,
396 blank_intersection_id,
397 &[idx_id],
398 ));
399 not_none_block.body.push(Instruction::store(
400 access_idx,
401 instance_custom_index_id,
402 None,
403 ));
404
405 let idx_id = self.get_index_constant(3);
407 let access_idx = self.id_gen.next();
408 not_none_block.body.push(Instruction::access_chain(
409 flag_pointer_type_id,
410 access_idx,
411 blank_intersection_id,
412 &[idx_id],
413 ));
414 not_none_block
415 .body
416 .push(Instruction::store(access_idx, instance_id, None));
417
418 let idx_id = self.get_index_constant(4);
419 let access_idx = self.id_gen.next();
420 not_none_block.body.push(Instruction::access_chain(
421 flag_pointer_type_id,
422 access_idx,
423 blank_intersection_id,
424 &[idx_id],
425 ));
426 not_none_block
427 .body
428 .push(Instruction::store(access_idx, sbt_record_offset_id, None));
429
430 let idx_id = self.get_index_constant(5);
431 let access_idx = self.id_gen.next();
432 not_none_block.body.push(Instruction::access_chain(
433 flag_pointer_type_id,
434 access_idx,
435 blank_intersection_id,
436 &[idx_id],
437 ));
438 not_none_block
439 .body
440 .push(Instruction::store(access_idx, geometry_index_id, None));
441
442 let idx_id = self.get_index_constant(6);
443 let access_idx = self.id_gen.next();
444 not_none_block.body.push(Instruction::access_chain(
445 flag_pointer_type_id,
446 access_idx,
447 blank_intersection_id,
448 &[idx_id],
449 ));
450 not_none_block
451 .body
452 .push(Instruction::store(access_idx, primitive_index_id, None));
453
454 let idx_id = self.get_index_constant(9);
455 let access_idx = self.id_gen.next();
456 not_none_block.body.push(Instruction::access_chain(
457 transform_pointer_type_id,
458 access_idx,
459 blank_intersection_id,
460 &[idx_id],
461 ));
462 not_none_block
463 .body
464 .push(Instruction::store(access_idx, object_to_world_id, None));
465
466 let idx_id = self.get_index_constant(10);
467 let access_idx = self.id_gen.next();
468 not_none_block.body.push(Instruction::access_chain(
469 transform_pointer_type_id,
470 access_idx,
471 blank_intersection_id,
472 &[idx_id],
473 ));
474 not_none_block
475 .body
476 .push(Instruction::store(access_idx, world_to_object_id, None));
477
478 let tri_comp_id = self.id_gen.next();
479 let tri_id = self.get_constant_scalar(crate::Literal::U32(
480 crate::RayQueryIntersection::Triangle as _,
481 ));
482 not_none_block.body.push(Instruction::binary(
483 spirv::Op::IEqual,
484 self.get_bool_type_id(),
485 tri_comp_id,
486 kind_id,
487 tri_id,
488 ));
489
490 let tri_label_id = self.id_gen.next();
491 let mut tri_block = Block::new(tri_label_id);
492
493 let merge_label_id = self.id_gen.next();
494 let merge_block = Block::new(merge_label_id);
495 {
497 let block = if is_committed {
498 &mut not_none_block
499 } else {
500 &mut tri_block
501 };
502 let t_id = self.id_gen.next();
503 block.body.push(Instruction::ray_query_get_intersection(
504 spirv::Op::RayQueryGetIntersectionTKHR,
505 scalar_type_id,
506 t_id,
507 query_id,
508 intersection_id,
509 ));
510 let idx_id = self.get_index_constant(1);
511 let access_idx = self.id_gen.next();
512 block.body.push(Instruction::access_chain(
513 float_pointer_type_id,
514 access_idx,
515 blank_intersection_id,
516 &[idx_id],
517 ));
518 block.body.push(Instruction::store(access_idx, t_id, None));
519 }
520 not_none_block.body.push(Instruction::selection_merge(
521 merge_label_id,
522 spirv::SelectionControl::NONE,
523 ));
524 function.consume(
525 not_none_block,
526 Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
527 );
528
529 let barycentrics_id = self.id_gen.next();
530 tri_block.body.push(Instruction::ray_query_get_intersection(
531 spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
532 barycentrics_type_id,
533 barycentrics_id,
534 query_id,
535 intersection_id,
536 ));
537
538 let front_face_id = self.id_gen.next();
539 tri_block.body.push(Instruction::ray_query_get_intersection(
540 spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
541 bool_type_id,
542 front_face_id,
543 query_id,
544 intersection_id,
545 ));
546
547 let idx_id = self.get_index_constant(7);
548 let access_idx = self.id_gen.next();
549 tri_block.body.push(Instruction::access_chain(
550 barycentrics_pointer_type_id,
551 access_idx,
552 blank_intersection_id,
553 &[idx_id],
554 ));
555 tri_block
556 .body
557 .push(Instruction::store(access_idx, barycentrics_id, None));
558
559 let idx_id = self.get_index_constant(8);
560 let access_idx = self.id_gen.next();
561 tri_block.body.push(Instruction::access_chain(
562 bool_pointer_type_id,
563 access_idx,
564 blank_intersection_id,
565 &[idx_id],
566 ));
567 tri_block
568 .body
569 .push(Instruction::store(access_idx, front_face_id, None));
570 function.consume(tri_block, Instruction::branch(merge_label_id));
571 function.consume(merge_block, Instruction::branch(outer_merge_label_id));
572 function.consume(outer_merge_block, Instruction::branch(final_label_id));
573
574 let loaded_blank_intersection_id = self.id_gen.next();
575 final_block.body.push(Instruction::load(
576 intersection_type_id,
577 loaded_blank_intersection_id,
578 blank_intersection_id,
579 None,
580 ));
581 function.consume(
582 final_block,
583 Instruction::return_value(loaded_blank_intersection_id),
584 );
585
586 function.to_words(&mut self.logical_layout.function_definitions);
587 self.ray_query_functions.insert(
588 LookupRayQueryFunction::GetIntersection {
589 committed: is_committed,
590 },
591 func_id,
592 );
593 func_id
594 }
595
596 fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word {
597 if let Some(&word) = self
598 .ray_query_functions
599 .get(&LookupRayQueryFunction::Initialize)
600 {
601 return word;
602 }
603
604 let ray_query_type_id = self.get_ray_query_pointer_id();
605 let acceleration_structure_type_id =
606 self.get_localtype_id(super::LocalType::AccelerationStructure);
607 let ray_desc_type_id = self.get_handle_type_id(
608 ir_module
609 .special_types
610 .ray_desc
611 .expect("ray desc should be set if ray queries are being initialized"),
612 );
613
614 let u32_ty = self.get_u32_type_id();
615 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
616
617 let f32_type_id = self.get_f32_type_id();
618 let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
619
620 let bool_type_id = self.get_bool_type_id();
621 let bool_vec3_type_id = self.get_vec3_bool_type_id();
622
623 let (func_id, mut function, arg_ids) = self.write_function_signature(
624 &[
625 ray_query_type_id,
626 acceleration_structure_type_id,
627 ray_desc_type_id,
628 u32_ptr_ty,
629 f32_ptr_ty,
630 ],
631 self.void_type,
632 );
633
634 let query_id = arg_ids[0];
635 let acceleration_structure_id = arg_ids[1];
636 let desc_id = arg_ids[2];
637 let init_tracker_id = arg_ids[3];
638 let t_max_tracker_id = arg_ids[4];
639
640 let label_id = self.id_gen.next();
641 let mut block = Block::new(label_id);
642
643 let flag_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
644
645 let ray_flags_id = self.id_gen.next();
647 block.body.push(Instruction::composite_extract(
648 flag_type_id,
649 ray_flags_id,
650 desc_id,
651 &[0],
652 ));
653 let cull_mask_id = self.id_gen.next();
654 block.body.push(Instruction::composite_extract(
655 flag_type_id,
656 cull_mask_id,
657 desc_id,
658 &[1],
659 ));
660
661 let tmin_id = self.id_gen.next();
662 block.body.push(Instruction::composite_extract(
663 f32_type_id,
664 tmin_id,
665 desc_id,
666 &[2],
667 ));
668 let tmax_id = self.id_gen.next();
669 block.body.push(Instruction::composite_extract(
670 f32_type_id,
671 tmax_id,
672 desc_id,
673 &[3],
674 ));
675 block
676 .body
677 .push(Instruction::store(t_max_tracker_id, tmax_id, None));
678
679 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
680 size: crate::VectorSize::Tri,
681 scalar: crate::Scalar::F32,
682 });
683 let ray_origin_id = self.id_gen.next();
684 block.body.push(Instruction::composite_extract(
685 vector_type_id,
686 ray_origin_id,
687 desc_id,
688 &[4],
689 ));
690 let ray_dir_id = self.id_gen.next();
691 block.body.push(Instruction::composite_extract(
692 vector_type_id,
693 ray_dir_id,
694 desc_id,
695 &[5],
696 ));
697
698 let valid_id = self.ray_query_initialization_tracking.then(||{
699 let tmin_le_tmax_id = self.id_gen.next();
700 block.body.push(Instruction::binary(
704 spirv::Op::FOrdLessThanEqual,
705 bool_type_id,
706 tmin_le_tmax_id,
707 tmin_id,
708 tmax_id,
709 ));
710
711 let tmin_ge_zero_id = self.id_gen.next();
715 let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0));
716 block.body.push(Instruction::binary(
717 spirv::Op::FOrdGreaterThanEqual,
718 bool_type_id,
719 tmin_ge_zero_id,
720 tmin_id,
721 zero_id,
722 ));
723
724 let ray_origin_infinite_id = self.id_gen.next();
726 block.body.push(Instruction::unary(
727 spirv::Op::IsInf,
728 bool_vec3_type_id,
729 ray_origin_infinite_id,
730 ray_origin_id,
731 ));
732 let any_ray_origin_infinite_id = self.id_gen.next();
733 block.body.push(Instruction::unary(
734 spirv::Op::Any,
735 bool_type_id,
736 any_ray_origin_infinite_id,
737 ray_origin_infinite_id,
738 ));
739
740 let ray_origin_nan_id = self.id_gen.next();
741 block.body.push(Instruction::unary(
742 spirv::Op::IsNan,
743 bool_vec3_type_id,
744 ray_origin_nan_id,
745 ray_origin_id,
746 ));
747 let any_ray_origin_nan_id = self.id_gen.next();
748 block.body.push(Instruction::unary(
749 spirv::Op::Any,
750 bool_type_id,
751 any_ray_origin_nan_id,
752 ray_origin_nan_id,
753 ));
754
755 let ray_origin_not_finite_id = self.id_gen.next();
756 block.body.push(Instruction::binary(
757 spirv::Op::LogicalOr,
758 bool_type_id,
759 ray_origin_not_finite_id,
760 any_ray_origin_nan_id,
761 any_ray_origin_infinite_id,
762 ));
763
764 let all_ray_origin_finite_id = self.id_gen.next();
765 block.body.push(Instruction::unary(
766 spirv::Op::LogicalNot,
767 bool_type_id,
768 all_ray_origin_finite_id,
769 ray_origin_not_finite_id,
770 ));
771
772 let ray_dir_infinite_id = self.id_gen.next();
774 block.body.push(Instruction::unary(
775 spirv::Op::IsInf,
776 bool_vec3_type_id,
777 ray_dir_infinite_id,
778 ray_dir_id,
779 ));
780 let any_ray_dir_infinite_id = self.id_gen.next();
781 block.body.push(Instruction::unary(
782 spirv::Op::Any,
783 bool_type_id,
784 any_ray_dir_infinite_id,
785 ray_dir_infinite_id,
786 ));
787
788 let ray_dir_nan_id = self.id_gen.next();
789 block.body.push(Instruction::unary(
790 spirv::Op::IsNan,
791 bool_vec3_type_id,
792 ray_dir_nan_id,
793 ray_dir_id,
794 ));
795 let any_ray_dir_nan_id = self.id_gen.next();
796 block.body.push(Instruction::unary(
797 spirv::Op::Any,
798 bool_type_id,
799 any_ray_dir_nan_id,
800 ray_dir_nan_id,
801 ));
802
803 let ray_dir_not_finite_id = self.id_gen.next();
804 block.body.push(Instruction::binary(
805 spirv::Op::LogicalOr,
806 bool_type_id,
807 ray_dir_not_finite_id,
808 any_ray_dir_nan_id,
809 any_ray_dir_infinite_id,
810 ));
811
812 let all_ray_dir_finite_id = self.id_gen.next();
813 block.body.push(Instruction::unary(
814 spirv::Op::LogicalNot,
815 bool_type_id,
816 all_ray_dir_finite_id,
817 ray_dir_not_finite_id,
818 ));
819
820 fn write_less_than_2_true(
825 writer: &mut Writer,
826 block: &mut Block,
827 mut bools: Vec<spirv::Word>,
828 ) -> spirv::Word {
829 assert!(bools.len() > 1, "Must have multiple booleans!");
830 let bool_ty = writer.get_bool_type_id();
831 let mut each_two_true = Vec::new();
832 while let Some(last_bool) = bools.pop() {
833 for &bool in &bools {
834 let both_true_id = writer.write_logical_and(
835 block,
836 last_bool,
837 bool,
838 );
839 each_two_true.push(both_true_id);
840 }
841 }
842 let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true`");
843 for two_true in each_two_true {
844 let new_all_or_id = writer.id_gen.next();
845 block.body.push(Instruction::binary(
846 spirv::Op::LogicalOr,
847 bool_ty,
848 new_all_or_id,
849 all_or_id,
850 two_true,
851 ));
852 all_or_id = new_all_or_id;
853 }
854
855 let less_than_two_id = writer.id_gen.next();
856 block.body.push(Instruction::unary(
857 spirv::Op::LogicalNot,
858 bool_ty,
859 less_than_two_id,
860 all_or_id,
861 ));
862 less_than_two_id
863 }
864
865 let contains_skip_triangles = write_ray_flags_contains_flags(
868 self,
869 &mut block,
870 ray_flags_id,
871 crate::RayFlag::SKIP_TRIANGLES.bits(),
872 );
873 let contains_skip_aabbs = write_ray_flags_contains_flags(
874 self,
875 &mut block,
876 ray_flags_id,
877 crate::RayFlag::SKIP_AABBS.bits(),
878 );
879
880 let not_contain_skip_triangles_aabbs = write_less_than_2_true(
881 self,
882 &mut block,
883 vec![contains_skip_triangles, contains_skip_aabbs],
884 );
885
886 let contains_cull_back = write_ray_flags_contains_flags(
889 self,
890 &mut block,
891 ray_flags_id,
892 crate::RayFlag::CULL_BACK_FACING.bits(),
893 );
894 let contains_cull_front = write_ray_flags_contains_flags(
895 self,
896 &mut block,
897 ray_flags_id,
898 crate::RayFlag::CULL_FRONT_FACING.bits(),
899 );
900
901 let not_contain_skip_triangles_cull = write_less_than_2_true(
902 self,
903 &mut block,
904 vec![
905 contains_skip_triangles,
906 contains_cull_back,
907 contains_cull_front,
908 ],
909 );
910
911 let contains_opaque = write_ray_flags_contains_flags(
914 self,
915 &mut block,
916 ray_flags_id,
917 crate::RayFlag::FORCE_OPAQUE.bits(),
918 );
919 let contains_no_opaque = write_ray_flags_contains_flags(
920 self,
921 &mut block,
922 ray_flags_id,
923 crate::RayFlag::FORCE_NO_OPAQUE.bits(),
924 );
925 let contains_cull_opaque = write_ray_flags_contains_flags(
926 self,
927 &mut block,
928 ray_flags_id,
929 crate::RayFlag::CULL_OPAQUE.bits(),
930 );
931 let contains_cull_no_opaque = write_ray_flags_contains_flags(
932 self,
933 &mut block,
934 ray_flags_id,
935 crate::RayFlag::CULL_NO_OPAQUE.bits(),
936 );
937
938 let not_contain_multiple_opaque = write_less_than_2_true(
939 self,
940 &mut block,
941 vec![
942 contains_opaque,
943 contains_no_opaque,
944 contains_cull_opaque,
945 contains_cull_no_opaque,
946 ],
947 );
948
949 self.write_reduce_and(
951 &mut block,
952 vec![
953 tmin_le_tmax_id,
954 tmin_ge_zero_id,
955 all_ray_origin_finite_id,
956 all_ray_dir_finite_id,
957 not_contain_skip_triangles_aabbs,
958 not_contain_skip_triangles_cull,
959 not_contain_multiple_opaque,
960 ],
961 )
962 });
963
964 let merge_label_id = self.id_gen.next();
965 let merge_block = Block::new(merge_label_id);
966
967 let invalid_label_id = self.id_gen.next();
969 let mut invalid_block = Block::new(invalid_label_id);
970
971 let valid_label_id = self.id_gen.next();
972 let mut valid_block = Block::new(valid_label_id);
973
974 match valid_id {
975 Some(all_valid_id) => {
976 block.body.push(Instruction::selection_merge(
977 merge_label_id,
978 spirv::SelectionControl::NONE,
979 ));
980 function.consume(
981 block,
982 Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
983 );
984 }
985 None => {
986 function.consume(block, Instruction::branch(valid_label_id));
987 }
988 }
989
990 valid_block.body.push(Instruction::ray_query_initialize(
991 query_id,
992 acceleration_structure_id,
993 ray_flags_id,
994 cull_mask_id,
995 ray_origin_id,
996 tmin_id,
997 ray_dir_id,
998 tmax_id,
999 ));
1000
1001 let const_initialized = self.get_constant_scalar(crate::Literal::U32(
1002 super::RayQueryPoint::INITIALIZED.bits(),
1003 ));
1004 valid_block
1005 .body
1006 .push(Instruction::store(init_tracker_id, const_initialized, None));
1007
1008 function.consume(valid_block, Instruction::branch(merge_label_id));
1009
1010 if self
1011 .flags
1012 .contains(super::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
1013 {
1014 self.write_debug_printf(
1015 &mut invalid_block,
1016 "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
1017 &[
1018 ray_flags_id,
1019 tmin_id,
1020 tmax_id,
1021 ray_origin_id,
1022 ray_dir_id,
1023 ],
1024 );
1025 }
1026
1027 function.consume(invalid_block, Instruction::branch(merge_label_id));
1028
1029 function.consume(merge_block, Instruction::return_void());
1030
1031 function.to_words(&mut self.logical_layout.function_definitions);
1032
1033 self.ray_query_functions
1034 .insert(LookupRayQueryFunction::Initialize, func_id);
1035 func_id
1036 }
1037
1038 fn write_ray_query_proceed(&mut self) -> spirv::Word {
1039 if let Some(&word) = self
1040 .ray_query_functions
1041 .get(&LookupRayQueryFunction::Proceed)
1042 {
1043 return word;
1044 }
1045
1046 let ray_query_type_id = self.get_ray_query_pointer_id();
1047
1048 let u32_ty = self.get_u32_type_id();
1049 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1050
1051 let bool_type_id = self.get_bool_type_id();
1052 let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
1053
1054 let (func_id, mut function, arg_ids) =
1055 self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id);
1056
1057 let query_id = arg_ids[0];
1058 let init_tracker_id = arg_ids[1];
1059
1060 let block_id = self.id_gen.next();
1061 let mut block = Block::new(block_id);
1062
1063 let proceeded_id = self.id_gen.next();
1065 let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
1066 block.body.push(Instruction::variable(
1067 bool_ptr_ty,
1068 proceeded_id,
1069 spirv::StorageClass::Function,
1070 Some(const_false),
1071 ));
1072
1073 let initialized_tracker_id = self.id_gen.next();
1074 block.body.push(Instruction::load(
1075 u32_ty,
1076 initialized_tracker_id,
1077 init_tracker_id,
1078 None,
1079 ));
1080
1081 let merge_id = self.id_gen.next();
1082 let mut merge_block = Block::new(merge_id);
1083
1084 let valid_block_id = self.id_gen.next();
1085 let mut valid_block = Block::new(valid_block_id);
1086
1087 let instruction = if self.ray_query_initialization_tracking {
1088 let is_initialized = write_ray_flags_contains_flags(
1089 self,
1090 &mut block,
1091 initialized_tracker_id,
1092 super::RayQueryPoint::INITIALIZED.bits(),
1093 );
1094
1095 block.body.push(Instruction::selection_merge(
1096 merge_id,
1097 spirv::SelectionControl::NONE,
1098 ));
1099
1100 Instruction::branch_conditional(is_initialized, valid_block_id, merge_id)
1101 } else {
1102 Instruction::branch(valid_block_id)
1103 };
1104
1105 function.consume(block, instruction);
1106
1107 let has_proceeded = self.id_gen.next();
1108 valid_block.body.push(Instruction::ray_query_proceed(
1109 bool_type_id,
1110 has_proceeded,
1111 query_id,
1112 ));
1113
1114 valid_block
1115 .body
1116 .push(Instruction::store(proceeded_id, has_proceeded, None));
1117
1118 let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
1119 (super::RayQueryPoint::PROCEED | super::RayQueryPoint::FINISHED_TRAVERSAL).bits(),
1120 ));
1121 let add_flag_continuing =
1122 self.get_constant_scalar(crate::Literal::U32(super::RayQueryPoint::PROCEED.bits()));
1123
1124 let add_flags_id = self.id_gen.next();
1125 valid_block.body.push(Instruction::select(
1126 u32_ty,
1127 add_flags_id,
1128 has_proceeded,
1129 add_flag_continuing,
1130 add_flag_finished,
1131 ));
1132 let final_flags = self.id_gen.next();
1133 valid_block.body.push(Instruction::binary(
1134 spirv::Op::BitwiseOr,
1135 u32_ty,
1136 final_flags,
1137 initialized_tracker_id,
1138 add_flags_id,
1139 ));
1140 valid_block
1141 .body
1142 .push(Instruction::store(init_tracker_id, final_flags, None));
1143
1144 function.consume(valid_block, Instruction::branch(merge_id));
1145
1146 let loaded_proceeded_id = self.id_gen.next();
1147 merge_block.body.push(Instruction::load(
1148 bool_type_id,
1149 loaded_proceeded_id,
1150 proceeded_id,
1151 None,
1152 ));
1153
1154 function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));
1155
1156 function.to_words(&mut self.logical_layout.function_definitions);
1157
1158 self.ray_query_functions
1159 .insert(LookupRayQueryFunction::Proceed, func_id);
1160 func_id
1161 }
1162
1163 fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
1164 if let Some(&word) = self
1165 .ray_query_functions
1166 .get(&LookupRayQueryFunction::GenerateIntersection)
1167 {
1168 return word;
1169 }
1170
1171 let ray_query_type_id = self.get_ray_query_pointer_id();
1172
1173 let u32_ty = self.get_u32_type_id();
1174 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1175
1176 let f32_type_id = self.get_f32_type_id();
1177 let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1178
1179 let bool_type_id = self.get_bool_type_id();
1180
1181 let (func_id, mut function, arg_ids) = self.write_function_signature(
1182 &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id],
1183 self.void_type,
1184 );
1185
1186 let query_id = arg_ids[0];
1187 let init_tracker_id = arg_ids[1];
1188 let depth_id = arg_ids[2];
1189 let t_max_tracker_id = arg_ids[3];
1190
1191 let block_id = self.id_gen.next();
1192 let mut block = Block::new(block_id);
1193
1194 let current_t = self.id_gen.next();
1195 block.body.push(Instruction::variable(
1196 f32_ptr_type_id,
1197 current_t,
1198 spirv::StorageClass::Function,
1199 None,
1200 ));
1201
1202 let current_t = self.id_gen.next();
1203 block.body.push(Instruction::variable(
1204 f32_ptr_type_id,
1205 current_t,
1206 spirv::StorageClass::Function,
1207 None,
1208 ));
1209
1210 let valid_id = self.id_gen.next();
1211 let mut valid_block = Block::new(valid_id);
1212
1213 let final_label_id = self.id_gen.next();
1214 let final_block = Block::new(final_label_id);
1215
1216 let instruction = if self.ray_query_initialization_tracking {
1217 let initialized_tracker_id = self.id_gen.next();
1218 block.body.push(Instruction::load(
1219 u32_ty,
1220 initialized_tracker_id,
1221 init_tracker_id,
1222 None,
1223 ));
1224
1225 let proceeded_id = write_ray_flags_contains_flags(
1226 self,
1227 &mut block,
1228 initialized_tracker_id,
1229 super::RayQueryPoint::PROCEED.bits(),
1230 );
1231 let finished_proceed_id = write_ray_flags_contains_flags(
1232 self,
1233 &mut block,
1234 initialized_tracker_id,
1235 super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1236 );
1237
1238 let not_finished_id = self.id_gen.next();
1241 block.body.push(Instruction::unary(
1242 spirv::Op::LogicalNot,
1243 bool_type_id,
1244 not_finished_id,
1245 finished_proceed_id,
1246 ));
1247
1248 let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1249
1250 block.body.push(Instruction::selection_merge(
1251 final_label_id,
1252 spirv::SelectionControl::NONE,
1253 ));
1254
1255 Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1256 } else {
1257 Instruction::branch(valid_id)
1258 };
1259
1260 function.consume(block, instruction);
1261
1262 let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1263 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1264 ));
1265 let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32(
1266 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
1267 ));
1268 let raw_kind_id = self.id_gen.next();
1269 valid_block
1270 .body
1271 .push(Instruction::ray_query_get_intersection(
1272 spirv::Op::RayQueryGetIntersectionTypeKHR,
1273 u32_ty,
1274 raw_kind_id,
1275 query_id,
1276 intersection_id,
1277 ));
1278
1279 let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
1280 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
1281 ));
1282 let intersection_aabb_id = self.id_gen.next();
1283 valid_block.body.push(Instruction::binary(
1284 spirv::Op::IEqual,
1285 bool_type_id,
1286 intersection_aabb_id,
1287 raw_kind_id,
1288 candidate_aabb_id,
1289 ));
1290
1291 let t_min_id = self.id_gen.next();
1296 valid_block.body.push(Instruction::ray_query_get_t_min(
1297 f32_type_id,
1298 t_min_id,
1299 query_id,
1300 ));
1301
1302 let committed_type_id = self.id_gen.next();
1319 valid_block
1320 .body
1321 .push(Instruction::ray_query_get_intersection(
1322 spirv::Op::RayQueryGetIntersectionTypeKHR,
1323 u32_ty,
1324 committed_type_id,
1325 query_id,
1326 committed_intersection_id,
1327 ));
1328
1329 let no_committed = self.id_gen.next();
1330 valid_block.body.push(Instruction::binary(
1331 spirv::Op::IEqual,
1332 bool_type_id,
1333 no_committed,
1334 committed_type_id,
1335 self.get_constant_scalar(crate::Literal::U32(
1336 spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _,
1337 )),
1338 ));
1339
1340 let next_valid_block_id = self.id_gen.next();
1341 let no_committed_block_id = self.id_gen.next();
1342 let mut no_committed_block = Block::new(no_committed_block_id);
1343 let committed_block_id = self.id_gen.next();
1344 let mut committed_block = Block::new(committed_block_id);
1345 valid_block.body.push(Instruction::selection_merge(
1346 next_valid_block_id,
1347 spirv::SelectionControl::NONE,
1348 ));
1349 function.consume(
1350 valid_block,
1351 Instruction::branch_conditional(
1352 no_committed,
1353 no_committed_block_id,
1354 committed_block_id,
1355 ),
1356 );
1357
1358 let t_max_id = self.id_gen.next();
1360 no_committed_block.body.push(Instruction::load(
1361 f32_type_id,
1362 t_max_id,
1363 t_max_tracker_id,
1364 None,
1365 ));
1366 no_committed_block
1367 .body
1368 .push(Instruction::store(current_t, t_max_id, None));
1369 function.consume(no_committed_block, Instruction::branch(next_valid_block_id));
1370
1371 let latest_t_id = self.id_gen.next();
1373 committed_block
1374 .body
1375 .push(Instruction::ray_query_get_intersection(
1376 spirv::Op::RayQueryGetIntersectionTKHR,
1377 f32_type_id,
1378 latest_t_id,
1379 query_id,
1380 intersection_id,
1381 ));
1382 committed_block
1383 .body
1384 .push(Instruction::store(current_t, latest_t_id, None));
1385 function.consume(committed_block, Instruction::branch(next_valid_block_id));
1386
1387 let mut valid_block = Block::new(next_valid_block_id);
1388
1389 let t_ge_t_min = self.id_gen.next();
1390 valid_block.body.push(Instruction::binary(
1391 spirv::Op::FOrdGreaterThanEqual,
1392 bool_type_id,
1393 t_ge_t_min,
1394 depth_id,
1395 t_min_id,
1396 ));
1397 let t_current = self.id_gen.next();
1398 valid_block
1399 .body
1400 .push(Instruction::load(f32_type_id, t_current, current_t, None));
1401 let t_le_t_current = self.id_gen.next();
1402 valid_block.body.push(Instruction::binary(
1403 spirv::Op::FOrdLessThanEqual,
1404 bool_type_id,
1405 t_le_t_current,
1406 depth_id,
1407 t_current,
1408 ));
1409
1410 let t_in_range = self.id_gen.next();
1411 valid_block.body.push(Instruction::binary(
1412 spirv::Op::LogicalAnd,
1413 bool_type_id,
1414 t_in_range,
1415 t_ge_t_min,
1416 t_le_t_current,
1417 ));
1418
1419 let call_valid_id = self.id_gen.next();
1420 valid_block.body.push(Instruction::binary(
1421 spirv::Op::LogicalAnd,
1422 bool_type_id,
1423 call_valid_id,
1424 t_in_range,
1425 intersection_aabb_id,
1426 ));
1427
1428 let generate_label_id = self.id_gen.next();
1429 let mut generate_block = Block::new(generate_label_id);
1430
1431 let merge_label_id = self.id_gen.next();
1432 let merge_block = Block::new(merge_label_id);
1433
1434 valid_block.body.push(Instruction::selection_merge(
1435 merge_label_id,
1436 spirv::SelectionControl::NONE,
1437 ));
1438 function.consume(
1439 valid_block,
1440 Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id),
1441 );
1442
1443 generate_block
1444 .body
1445 .push(Instruction::ray_query_generate_intersection(
1446 query_id, depth_id,
1447 ));
1448
1449 function.consume(generate_block, Instruction::branch(merge_label_id));
1450 function.consume(merge_block, Instruction::branch(final_label_id));
1451
1452 function.consume(final_block, Instruction::return_void());
1453
1454 function.to_words(&mut self.logical_layout.function_definitions);
1455
1456 self.ray_query_functions
1457 .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
1458 func_id
1459 }
1460
/// Emits (once) and returns the id of the SPIR-V helper function that implements
/// `RayQueryFunction::ConfirmIntersection`.
///
/// The helper has the signature `(ray_query_ptr, init_tracker_ptr: *u32) -> void`.
/// It emits `Instruction::ray_query_confirm_intersection` only when the current
/// candidate intersection is a triangle. When `ray_query_initialization_tracking` is
/// enabled, it additionally requires the tracker to show that the query has proceeded
/// and has not finished traversal. Otherwise the helper does nothing.
///
/// The function id is cached in `ray_query_functions`. The statement order here
/// determines id allocation order in the output, so keep it stable.
fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
    // Reuse the helper if it has already been emitted.
    if let Some(&word) = self
        .ray_query_functions
        .get(&LookupRayQueryFunction::ConfirmIntersection)
    {
        return word;
    }

    let ray_query_type_id = self.get_ray_query_pointer_id();

    let u32_ty = self.get_u32_type_id();
    let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

    let bool_type_id = self.get_bool_type_id();

    // Signature: (ray query pointer, initialization-tracker pointer) -> void.
    let (func_id, mut function, arg_ids) =
        self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);

    let query_id = arg_ids[0];
    let init_tracker_id = arg_ids[1];

    let block_id = self.id_gen.next();
    let mut block = Block::new(block_id);

    let valid_id = self.id_gen.next();
    let mut valid_block = Block::new(valid_id);

    // Common exit block; every path ends with `OpReturn` here.
    let final_label_id = self.id_gen.next();
    let final_block = Block::new(final_label_id);

    // With tracking enabled, only enter `valid_block` when the tracker has the PROCEED
    // flag set and the FINISHED_TRAVERSAL flag clear; otherwise branch unconditionally.
    let instruction = if self.ray_query_initialization_tracking {
        let initialized_tracker_id = self.id_gen.next();
        block.body.push(Instruction::load(
            u32_ty,
            initialized_tracker_id,
            init_tracker_id,
            None,
        ));

        let proceeded_id = write_ray_flags_contains_flags(
            self,
            &mut block,
            initialized_tracker_id,
            super::RayQueryPoint::PROCEED.bits(),
        );
        let finished_proceed_id = write_ray_flags_contains_flags(
            self,
            &mut block,
            initialized_tracker_id,
            super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
        );
        let not_finished_id = self.id_gen.next();
        block.body.push(Instruction::unary(
            spirv::Op::LogicalNot,
            bool_type_id,
            not_finished_id,
            finished_proceed_id,
        ));

        let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);

        block.body.push(Instruction::selection_merge(
            final_label_id,
            spirv::SelectionControl::NONE,
        ));

        Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
    } else {
        Instruction::branch(valid_id)
    };

    function.consume(block, instruction);

    // Query the type of the current candidate intersection.
    let intersection_id = self.get_constant_scalar(crate::Literal::U32(
        spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
    ));
    let raw_kind_id = self.id_gen.next();
    valid_block
        .body
        .push(Instruction::ray_query_get_intersection(
            spirv::Op::RayQueryGetIntersectionTypeKHR,
            u32_ty,
            raw_kind_id,
            query_id,
            intersection_id,
        ));

    // The confirm is only emitted for triangle candidates.
    let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
        spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
    ));
    let intersection_tri_id = self.id_gen.next();
    valid_block.body.push(Instruction::binary(
        spirv::Op::IEqual,
        bool_type_id,
        intersection_tri_id,
        raw_kind_id,
        candidate_tri_id,
    ));

    let generate_label_id = self.id_gen.next();
    let mut generate_block = Block::new(generate_label_id);

    let merge_label_id = self.id_gen.next();
    let merge_block = Block::new(merge_label_id);

    valid_block.body.push(Instruction::selection_merge(
        merge_label_id,
        spirv::SelectionControl::NONE,
    ));
    function.consume(
        valid_block,
        Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
    );

    generate_block
        .body
        .push(Instruction::ray_query_confirm_intersection(query_id));

    function.consume(generate_block, Instruction::branch(merge_label_id));
    function.consume(merge_block, Instruction::branch(final_label_id));

    function.consume(final_block, Instruction::return_void());

    // Cache the id before serializing; both steps happen unconditionally here.
    self.ray_query_functions
        .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);

    function.to_words(&mut self.logical_layout.function_definitions);

    func_id
}
1592
/// Emits (once per `is_committed` value) and returns the id of the SPIR-V helper
/// function that fetches the vertex positions of a triangle intersection.
///
/// `is_committed == true` inspects the committed intersection; `false` inspects the
/// candidate intersection. The helper has the signature
/// `(ray_query_ptr, init_tracker_ptr: *u32) -> ray_vertex_return`. The result starts as
/// the null constant of that type. It is overwritten with the fetched positions only
/// when the selected intersection's type is "triangle". When
/// `ray_query_initialization_tracking` is enabled, it additionally requires the tracker
/// to show PROCEED set and FINISHED_TRAVERSAL set (committed) or clear (candidate).
///
/// # Panics
/// Panics if `ir_module.special_types.ray_vertex_return` is `None`.
fn write_ray_query_get_vertex_positions(
    &mut self,
    is_committed: bool,
    ir_module: &crate::Module,
) -> spirv::Word {
    // Separate helpers are cached for the committed and candidate variants.
    if let Some(&word) =
        self.ray_query_functions
            .get(&LookupRayQueryFunction::GetVertexPositions {
                committed: is_committed,
            })
    {
        return word;
    }

    // (intersection selector, "triangle" type value for that selector). Despite the
    // names, these hold candidate values when `is_committed` is false.
    let (committed_ty, committed_tri_ty) = if is_committed {
        (
            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
            spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
                as u32,
        )
    } else {
        (
            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
                as u32,
        )
    };

    let ray_query_type_id = self.get_ray_query_pointer_id();

    let u32_ty = self.get_u32_type_id();
    let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);

    let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
        *ir_module
            .special_types
            .ray_vertex_return
            .as_ref()
            .expect("must be generated when reading in get vertex position"),
    );
    let ptr_return_ty =
        self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);

    let bool_type_id = self.get_bool_type_id();

    let (func_id, mut function, arg_ids) = self.write_function_signature(
        &[ray_query_type_id, u32_ptr_ty],
        rq_get_vertex_positions_ty_id,
    );

    let query_id = arg_ids[0];
    let init_tracker_id = arg_ids[1];

    let block_id = self.id_gen.next();
    let mut block = Block::new(block_id);

    // Local result slot, initialized to the null constant so that every path returns
    // a defined value.
    let return_id = self.id_gen.next();
    block.body.push(Instruction::variable(
        ptr_return_ty,
        return_id,
        spirv::StorageClass::Function,
        Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
    ));

    let valid_id = self.id_gen.next();
    let mut valid_block = Block::new(valid_id);

    // Common exit block: loads and returns the result slot.
    let final_label_id = self.id_gen.next();
    let mut final_block = Block::new(final_label_id);

    let instruction = if self.ray_query_initialization_tracking {
        let initialized_tracker_id = self.id_gen.next();
        block.body.push(Instruction::load(
            u32_ty,
            initialized_tracker_id,
            init_tracker_id,
            None,
        ));

        let proceeded_id = write_ray_flags_contains_flags(
            self,
            &mut block,
            initialized_tracker_id,
            super::RayQueryPoint::PROCEED.bits(),
        );
        let finished_proceed_id = write_ray_flags_contains_flags(
            self,
            &mut block,
            initialized_tracker_id,
            super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
        );

        // Committed variant requires traversal to be finished; candidate variant
        // requires it to still be in progress.
        let correct_finish_id = if is_committed {
            finished_proceed_id
        } else {
            let not_finished_id = self.id_gen.next();
            block.body.push(Instruction::unary(
                spirv::Op::LogicalNot,
                bool_type_id,
                not_finished_id,
                finished_proceed_id,
            ));
            not_finished_id
        };

        let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id);
        block.body.push(Instruction::selection_merge(
            final_label_id,
            spirv::SelectionControl::NONE,
        ));
        Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
    } else {
        Instruction::branch(valid_id)
    };

    function.consume(block, instruction);

    // Only fetch positions when the selected intersection is a triangle.
    let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
    let raw_kind_id = self.id_gen.next();
    valid_block
        .body
        .push(Instruction::ray_query_get_intersection(
            spirv::Op::RayQueryGetIntersectionTypeKHR,
            u32_ty,
            raw_kind_id,
            query_id,
            intersection_id,
        ));

    let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
    let intersection_tri_id = self.id_gen.next();
    valid_block.body.push(Instruction::binary(
        spirv::Op::IEqual,
        bool_type_id,
        intersection_tri_id,
        raw_kind_id,
        candidate_tri_id,
    ));

    let generate_label_id = self.id_gen.next();
    let mut vertex_return_block = Block::new(generate_label_id);

    let merge_label_id = self.id_gen.next();
    let merge_block = Block::new(merge_label_id);

    valid_block.body.push(Instruction::selection_merge(
        merge_label_id,
        spirv::SelectionControl::NONE,
    ));
    function.consume(
        valid_block,
        Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
    );

    // Fetch the positions for the selected intersection and store them in the slot.
    let vertices_id = self.id_gen.next();
    vertex_return_block
        .body
        .push(Instruction::ray_query_return_vertex_position(
            rq_get_vertex_positions_ty_id,
            vertices_id,
            query_id,
            intersection_id,
        ));
    vertex_return_block
        .body
        .push(Instruction::store(return_id, vertices_id, None));

    function.consume(vertex_return_block, Instruction::branch(merge_label_id));
    function.consume(merge_block, Instruction::branch(final_label_id));

    let loaded_pos_id = self.id_gen.next();
    final_block.body.push(Instruction::load(
        rq_get_vertex_positions_ty_id,
        loaded_pos_id,
        return_id,
        None,
    ));

    function.consume(final_block, Instruction::return_value(loaded_pos_id));

    self.ray_query_functions.insert(
        LookupRayQueryFunction::GetVertexPositions {
            committed: is_committed,
        },
        func_id,
    );

    function.to_words(&mut self.logical_layout.function_definitions);

    func_id
}
1784}
1785
1786impl BlockContext<'_> {
1787 pub(super) fn write_ray_query_function(
1788 &mut self,
1789 query: Handle<crate::Expression>,
1790 function: &crate::RayQueryFunction,
1791 block: &mut Block,
1792 ) {
1793 let query_id = self.cached[query];
1794 let tracker_ids = *self
1795 .ray_query_tracker_expr
1796 .get(&query)
1797 .expect("not a cached ray query");
1798
1799 match *function {
1800 crate::RayQueryFunction::Initialize {
1801 acceleration_structure,
1802 descriptor,
1803 } => {
1804 let desc_id = self.cached[descriptor];
1805 let acc_struct_id = self.get_handle_id(acceleration_structure);
1806
1807 let func = self.writer.write_ray_query_initialize(self.ir_module);
1808
1809 let func_id = self.gen_id();
1810 block.body.push(Instruction::function_call(
1811 self.writer.void_type,
1812 func_id,
1813 func,
1814 &[
1815 query_id,
1816 acc_struct_id,
1817 desc_id,
1818 tracker_ids.initialized_tracker,
1819 tracker_ids.t_max_tracker,
1820 ],
1821 ));
1822 }
1823 crate::RayQueryFunction::Proceed { result } => {
1824 let id = self.gen_id();
1825 self.cached[result] = id;
1826
1827 let bool_ty = self.writer.get_bool_type_id();
1828
1829 let func_id = self.writer.write_ray_query_proceed();
1830 block.body.push(Instruction::function_call(
1831 bool_ty,
1832 id,
1833 func_id,
1834 &[query_id, tracker_ids.initialized_tracker],
1835 ));
1836 }
1837 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1838 let hit_id = self.cached[hit_t];
1839
1840 let func_id = self.writer.write_ray_query_generate_intersection();
1841
1842 let func_call_id = self.gen_id();
1843 block.body.push(Instruction::function_call(
1844 self.writer.void_type,
1845 func_call_id,
1846 func_id,
1847 &[
1848 query_id,
1849 tracker_ids.initialized_tracker,
1850 hit_id,
1851 tracker_ids.t_max_tracker,
1852 ],
1853 ));
1854 }
1855 crate::RayQueryFunction::ConfirmIntersection => {
1856 let func_id = self.writer.write_ray_query_confirm_intersection();
1857
1858 let func_call_id = self.gen_id();
1859 block.body.push(Instruction::function_call(
1860 self.writer.void_type,
1861 func_call_id,
1862 func_id,
1863 &[query_id, tracker_ids.initialized_tracker],
1864 ));
1865 }
1866 crate::RayQueryFunction::Terminate => {}
1867 }
1868 }
1869
1870 pub(super) fn write_ray_query_return_vertex_position(
1871 &mut self,
1872 query: Handle<crate::Expression>,
1873 block: &mut Block,
1874 is_committed: bool,
1875 ) -> spirv::Word {
1876 let fn_id = self
1877 .writer
1878 .write_ray_query_get_vertex_positions(is_committed, self.ir_module);
1879
1880 let query_id = self.cached[query];
1881 let tracker_id = *self
1882 .ray_query_tracker_expr
1883 .get(&query)
1884 .expect("not a cached ray query");
1885
1886 let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1887 *self
1888 .ir_module
1889 .special_types
1890 .ray_vertex_return
1891 .as_ref()
1892 .expect("must be generated when reading in get vertex position"),
1893 );
1894
1895 let func_call_id = self.gen_id();
1896 block.body.push(Instruction::function_call(
1897 rq_get_vertex_positions_ty_id,
1898 func_call_id,
1899 fn_id,
1900 &[query_id, tracker_id.initialized_tracker],
1901 ));
1902 func_call_id
1903 }
1904}