1use alloc::{vec, vec::Vec};
6
7use super::{
8 Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType,
9 LookupRayQueryFunction, NumericType, Writer,
10};
11use crate::{arena::Handle, back::RayQueryPoint};
12
13fn write_ray_flags_contains_flags(
15 writer: &mut Writer,
16 block: &mut Block,
17 id: spirv::Word,
18 flag: u32,
19) -> spirv::Word {
20 let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag));
21 let zero_id = writer.get_constant_scalar(crate::Literal::U32(0));
22 let u32_type_id = writer.get_u32_type_id();
23 let bool_ty = writer.get_bool_type_id();
24
25 let and_id = writer.id_gen.next();
26 block.body.push(Instruction::binary(
27 spirv::Op::BitwiseAnd,
28 u32_type_id,
29 and_id,
30 id,
31 bit_id,
32 ));
33
34 let eq_id = writer.id_gen.next();
35 block.body.push(Instruction::binary(
36 spirv::Op::INotEqual,
37 bool_ty,
38 eq_id,
39 and_id,
40 zero_id,
41 ));
42
43 eq_id
44}
45
46impl Writer {
47 fn write_logical_and(
49 &mut self,
50 block: &mut Block,
51 one: spirv::Word,
52 two: spirv::Word,
53 ) -> spirv::Word {
54 let id = self.id_gen.next();
55 let bool_id = self.get_bool_type_id();
56 block.body.push(Instruction::binary(
57 spirv::Op::LogicalAnd,
58 bool_id,
59 id,
60 one,
61 two,
62 ));
63 id
64 }
65
66 fn write_reduce_and(&mut self, block: &mut Block, mut bools: Vec<spirv::Word>) -> spirv::Word {
67 let mut current_combined = bools.pop().unwrap();
69 for boolean in bools {
70 current_combined = self.write_logical_and(block, current_combined, boolean)
71 }
72 current_combined
73 }
74
75 fn write_function_signature(
77 &mut self,
78 arg_types: &[spirv::Word],
79 return_ty: spirv::Word,
80 ) -> (spirv::Word, Function, Vec<spirv::Word>) {
81 let func_ty = self.get_function_type(LookupFunctionType {
82 parameter_type_ids: Vec::from(arg_types),
83 return_type_id: return_ty,
84 });
85
86 let mut function = Function::default();
87 let func_id = self.id_gen.next();
88 function.signature = Some(Instruction::function(
89 return_ty,
90 func_id,
91 spirv::FunctionControl::empty(),
92 func_ty,
93 ));
94
95 let mut arg_ids = Vec::with_capacity(arg_types.len());
96
97 for (idx, &arg_ty) in arg_types.iter().enumerate() {
98 let id = self.id_gen.next();
99 let instruction = Instruction::function_parameter(arg_ty, id);
100 function.parameters.push(FunctionArgument {
101 instruction,
102 handle_id: idx as u32,
103 });
104 arg_ids.push(id);
105 }
106 (func_id, function, arg_ids)
107 }
108
109 pub(super) fn write_ray_query_get_intersection_function(
110 &mut self,
111 is_committed: bool,
112 ir_module: &crate::Module,
113 ) -> spirv::Word {
114 if let Some(&word) =
115 self.ray_query_functions
116 .get(&LookupRayQueryFunction::GetIntersection {
117 committed: is_committed,
118 })
119 {
120 return word;
121 }
122 let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
123 let intersection_type_id = self.get_handle_type_id(ray_intersection);
124 let intersection_pointer_type_id =
125 self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
126
127 let flag_type_id = self.get_u32_type_id();
128 let flag_pointer_type_id =
129 self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
130
131 let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
132 columns: crate::VectorSize::Quad,
133 rows: crate::VectorSize::Tri,
134 scalar: crate::Scalar::F32,
135 });
136 let transform_pointer_type_id =
137 self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
138
139 let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
140 size: crate::VectorSize::Bi,
141 scalar: crate::Scalar::F32,
142 });
143 let barycentrics_pointer_type_id =
144 self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
145
146 let bool_type_id = self.get_bool_type_id();
147 let bool_pointer_type_id =
148 self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
149
150 let scalar_type_id = self.get_f32_type_id();
151 let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
152
153 let argument_type_id = self.get_ray_query_pointer_id();
154
155 let (func_id, mut function, arg_ids) = self.write_function_signature(
156 &[argument_type_id, flag_pointer_type_id],
157 intersection_type_id,
158 );
159
160 let query_id = arg_ids[0];
161 let intersection_tracker_id = arg_ids[1];
162
163 let label_id = self.id_gen.next();
164 let mut block = Block::new(label_id);
165
166 let blank_intersection = self.get_constant_null(intersection_type_id);
167 let blank_intersection_id = self.id_gen.next();
168 block.body.push(Instruction::variable(
170 intersection_pointer_type_id,
171 blank_intersection_id,
172 spirv::StorageClass::Function,
173 Some(blank_intersection),
174 ));
175
176 let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
177 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
178 } else {
179 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
180 } as _));
181
182 let loaded_ray_query_tracker_id = self.id_gen.next();
183 block.body.push(Instruction::load(
184 flag_type_id,
185 loaded_ray_query_tracker_id,
186 intersection_tracker_id,
187 None,
188 ));
189 let proceeded_id = write_ray_flags_contains_flags(
190 self,
191 &mut block,
192 loaded_ray_query_tracker_id,
193 RayQueryPoint::PROCEED.bits(),
194 );
195 let finished_proceed_id = write_ray_flags_contains_flags(
196 self,
197 &mut block,
198 loaded_ray_query_tracker_id,
199 RayQueryPoint::FINISHED_TRAVERSAL.bits(),
200 );
201 let proceed_finished_correct_id = if is_committed {
202 finished_proceed_id
203 } else {
204 let not_finished_id = self.id_gen.next();
205 block.body.push(Instruction::unary(
206 spirv::Op::LogicalNot,
207 bool_type_id,
208 not_finished_id,
209 finished_proceed_id,
210 ));
211 not_finished_id
212 };
213
214 let is_valid_id =
215 self.write_logical_and(&mut block, proceed_finished_correct_id, proceeded_id);
216
217 let valid_id = self.id_gen.next();
218 let mut valid_block = Block::new(valid_id);
219
220 let final_label_id = self.id_gen.next();
221 let mut final_block = Block::new(final_label_id);
222
223 block.body.push(Instruction::selection_merge(
224 final_label_id,
225 spirv::SelectionControl::NONE,
226 ));
227 function.consume(
228 block,
229 Instruction::branch_conditional(is_valid_id, valid_id, final_label_id),
230 );
231
232 let raw_kind_id = self.id_gen.next();
233 valid_block
234 .body
235 .push(Instruction::ray_query_get_intersection(
236 spirv::Op::RayQueryGetIntersectionTypeKHR,
237 flag_type_id,
238 raw_kind_id,
239 query_id,
240 intersection_id,
241 ));
242 let kind_id = if is_committed {
243 raw_kind_id
245 } else {
246 let condition_id = self.id_gen.next();
248 let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
249 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
250 as _,
251 ));
252 valid_block.body.push(Instruction::binary(
253 spirv::Op::IEqual,
254 self.get_bool_type_id(),
255 condition_id,
256 raw_kind_id,
257 committed_triangle_kind_id,
258 ));
259 let kind_id = self.id_gen.next();
260 valid_block.body.push(Instruction::select(
261 flag_type_id,
262 kind_id,
263 condition_id,
264 self.get_constant_scalar(crate::Literal::U32(
265 crate::RayQueryIntersection::Triangle as _,
266 )),
267 self.get_constant_scalar(crate::Literal::U32(
268 crate::RayQueryIntersection::Aabb as _,
269 )),
270 ));
271 kind_id
272 };
273 let idx_id = self.get_index_constant(0);
274 let access_idx = self.id_gen.next();
275 valid_block.body.push(Instruction::access_chain(
276 flag_pointer_type_id,
277 access_idx,
278 blank_intersection_id,
279 &[idx_id],
280 ));
281 valid_block
282 .body
283 .push(Instruction::store(access_idx, kind_id, None));
284
285 let not_none_comp_id = self.id_gen.next();
286 let none_id =
287 self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
288 valid_block.body.push(Instruction::binary(
289 spirv::Op::INotEqual,
290 self.get_bool_type_id(),
291 not_none_comp_id,
292 kind_id,
293 none_id,
294 ));
295
296 let not_none_label_id = self.id_gen.next();
297 let mut not_none_block = Block::new(not_none_label_id);
298
299 let outer_merge_label_id = self.id_gen.next();
300 let outer_merge_block = Block::new(outer_merge_label_id);
301
302 valid_block.body.push(Instruction::selection_merge(
303 outer_merge_label_id,
304 spirv::SelectionControl::NONE,
305 ));
306 function.consume(
307 valid_block,
308 Instruction::branch_conditional(
309 not_none_comp_id,
310 not_none_label_id,
311 outer_merge_label_id,
312 ),
313 );
314
315 let instance_custom_index_id = self.id_gen.next();
316 not_none_block
317 .body
318 .push(Instruction::ray_query_get_intersection(
319 spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
320 flag_type_id,
321 instance_custom_index_id,
322 query_id,
323 intersection_id,
324 ));
325 let instance_id = self.id_gen.next();
326 not_none_block
327 .body
328 .push(Instruction::ray_query_get_intersection(
329 spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
330 flag_type_id,
331 instance_id,
332 query_id,
333 intersection_id,
334 ));
335 let sbt_record_offset_id = self.id_gen.next();
336 not_none_block
337 .body
338 .push(Instruction::ray_query_get_intersection(
339 spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
340 flag_type_id,
341 sbt_record_offset_id,
342 query_id,
343 intersection_id,
344 ));
345 let geometry_index_id = self.id_gen.next();
346 not_none_block
347 .body
348 .push(Instruction::ray_query_get_intersection(
349 spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
350 flag_type_id,
351 geometry_index_id,
352 query_id,
353 intersection_id,
354 ));
355 let primitive_index_id = self.id_gen.next();
356 not_none_block
357 .body
358 .push(Instruction::ray_query_get_intersection(
359 spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
360 flag_type_id,
361 primitive_index_id,
362 query_id,
363 intersection_id,
364 ));
365
366 let object_to_world_id = self.id_gen.next();
370 not_none_block
371 .body
372 .push(Instruction::ray_query_get_intersection(
373 spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
374 transform_type_id,
375 object_to_world_id,
376 query_id,
377 intersection_id,
378 ));
379 let world_to_object_id = self.id_gen.next();
380 not_none_block
381 .body
382 .push(Instruction::ray_query_get_intersection(
383 spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
384 transform_type_id,
385 world_to_object_id,
386 query_id,
387 intersection_id,
388 ));
389
390 let idx_id = self.get_index_constant(2);
392 let access_idx = self.id_gen.next();
393 not_none_block.body.push(Instruction::access_chain(
394 flag_pointer_type_id,
395 access_idx,
396 blank_intersection_id,
397 &[idx_id],
398 ));
399 not_none_block.body.push(Instruction::store(
400 access_idx,
401 instance_custom_index_id,
402 None,
403 ));
404
405 let idx_id = self.get_index_constant(3);
407 let access_idx = self.id_gen.next();
408 not_none_block.body.push(Instruction::access_chain(
409 flag_pointer_type_id,
410 access_idx,
411 blank_intersection_id,
412 &[idx_id],
413 ));
414 not_none_block
415 .body
416 .push(Instruction::store(access_idx, instance_id, None));
417
418 let idx_id = self.get_index_constant(4);
419 let access_idx = self.id_gen.next();
420 not_none_block.body.push(Instruction::access_chain(
421 flag_pointer_type_id,
422 access_idx,
423 blank_intersection_id,
424 &[idx_id],
425 ));
426 not_none_block
427 .body
428 .push(Instruction::store(access_idx, sbt_record_offset_id, None));
429
430 let idx_id = self.get_index_constant(5);
431 let access_idx = self.id_gen.next();
432 not_none_block.body.push(Instruction::access_chain(
433 flag_pointer_type_id,
434 access_idx,
435 blank_intersection_id,
436 &[idx_id],
437 ));
438 not_none_block
439 .body
440 .push(Instruction::store(access_idx, geometry_index_id, None));
441
442 let idx_id = self.get_index_constant(6);
443 let access_idx = self.id_gen.next();
444 not_none_block.body.push(Instruction::access_chain(
445 flag_pointer_type_id,
446 access_idx,
447 blank_intersection_id,
448 &[idx_id],
449 ));
450 not_none_block
451 .body
452 .push(Instruction::store(access_idx, primitive_index_id, None));
453
454 let idx_id = self.get_index_constant(9);
455 let access_idx = self.id_gen.next();
456 not_none_block.body.push(Instruction::access_chain(
457 transform_pointer_type_id,
458 access_idx,
459 blank_intersection_id,
460 &[idx_id],
461 ));
462 not_none_block
463 .body
464 .push(Instruction::store(access_idx, object_to_world_id, None));
465
466 let idx_id = self.get_index_constant(10);
467 let access_idx = self.id_gen.next();
468 not_none_block.body.push(Instruction::access_chain(
469 transform_pointer_type_id,
470 access_idx,
471 blank_intersection_id,
472 &[idx_id],
473 ));
474 not_none_block
475 .body
476 .push(Instruction::store(access_idx, world_to_object_id, None));
477
478 let tri_comp_id = self.id_gen.next();
479 let tri_id = self.get_constant_scalar(crate::Literal::U32(
480 crate::RayQueryIntersection::Triangle as _,
481 ));
482 not_none_block.body.push(Instruction::binary(
483 spirv::Op::IEqual,
484 self.get_bool_type_id(),
485 tri_comp_id,
486 kind_id,
487 tri_id,
488 ));
489
490 let tri_label_id = self.id_gen.next();
491 let mut tri_block = Block::new(tri_label_id);
492
493 let merge_label_id = self.id_gen.next();
494 let merge_block = Block::new(merge_label_id);
495 {
497 let block = if is_committed {
498 &mut not_none_block
499 } else {
500 &mut tri_block
501 };
502 let t_id = self.id_gen.next();
503 block.body.push(Instruction::ray_query_get_intersection(
504 spirv::Op::RayQueryGetIntersectionTKHR,
505 scalar_type_id,
506 t_id,
507 query_id,
508 intersection_id,
509 ));
510 let idx_id = self.get_index_constant(1);
511 let access_idx = self.id_gen.next();
512 block.body.push(Instruction::access_chain(
513 float_pointer_type_id,
514 access_idx,
515 blank_intersection_id,
516 &[idx_id],
517 ));
518 block.body.push(Instruction::store(access_idx, t_id, None));
519 }
520 not_none_block.body.push(Instruction::selection_merge(
521 merge_label_id,
522 spirv::SelectionControl::NONE,
523 ));
524 function.consume(
525 not_none_block,
526 Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
527 );
528
529 let barycentrics_id = self.id_gen.next();
530 tri_block.body.push(Instruction::ray_query_get_intersection(
531 spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
532 barycentrics_type_id,
533 barycentrics_id,
534 query_id,
535 intersection_id,
536 ));
537
538 let front_face_id = self.id_gen.next();
539 tri_block.body.push(Instruction::ray_query_get_intersection(
540 spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
541 bool_type_id,
542 front_face_id,
543 query_id,
544 intersection_id,
545 ));
546
547 let idx_id = self.get_index_constant(7);
548 let access_idx = self.id_gen.next();
549 tri_block.body.push(Instruction::access_chain(
550 barycentrics_pointer_type_id,
551 access_idx,
552 blank_intersection_id,
553 &[idx_id],
554 ));
555 tri_block
556 .body
557 .push(Instruction::store(access_idx, barycentrics_id, None));
558
559 let idx_id = self.get_index_constant(8);
560 let access_idx = self.id_gen.next();
561 tri_block.body.push(Instruction::access_chain(
562 bool_pointer_type_id,
563 access_idx,
564 blank_intersection_id,
565 &[idx_id],
566 ));
567 tri_block
568 .body
569 .push(Instruction::store(access_idx, front_face_id, None));
570 function.consume(tri_block, Instruction::branch(merge_label_id));
571 function.consume(merge_block, Instruction::branch(outer_merge_label_id));
572 function.consume(outer_merge_block, Instruction::branch(final_label_id));
573
574 let loaded_blank_intersection_id = self.id_gen.next();
575 final_block.body.push(Instruction::load(
576 intersection_type_id,
577 loaded_blank_intersection_id,
578 blank_intersection_id,
579 None,
580 ));
581 function.consume(
582 final_block,
583 Instruction::return_value(loaded_blank_intersection_id),
584 );
585
586 function.to_words(&mut self.logical_layout.function_definitions);
587 self.ray_query_functions.insert(
588 LookupRayQueryFunction::GetIntersection {
589 committed: is_committed,
590 },
591 func_id,
592 );
593 func_id
594 }
595
596 fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word {
597 if let Some(&word) = self
598 .ray_query_functions
599 .get(&LookupRayQueryFunction::Initialize)
600 {
601 return word;
602 }
603
604 let ray_query_type_id = self.get_ray_query_pointer_id();
605 let acceleration_structure_type_id =
606 self.get_localtype_id(super::LocalType::AccelerationStructure);
607 let ray_desc_type_id = self.get_handle_type_id(
608 ir_module
609 .special_types
610 .ray_desc
611 .expect("ray desc should be set if ray queries are being initialized"),
612 );
613
614 let u32_ty = self.get_u32_type_id();
615 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
616
617 let f32_type_id = self.get_f32_type_id();
618 let f32_ptr_ty = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
619
620 let bool_type_id = self.get_bool_type_id();
621 let bool_vec3_type_id = self.get_vec3_bool_type_id();
622
623 let (func_id, mut function, arg_ids) = self.write_function_signature(
624 &[
625 ray_query_type_id,
626 acceleration_structure_type_id,
627 ray_desc_type_id,
628 u32_ptr_ty,
629 f32_ptr_ty,
630 ],
631 self.void_type,
632 );
633
634 let query_id = arg_ids[0];
635 let acceleration_structure_id = arg_ids[1];
636 let desc_id = arg_ids[2];
637 let init_tracker_id = arg_ids[3];
638 let t_max_tracker_id = arg_ids[4];
639
640 let label_id = self.id_gen.next();
641 let mut block = Block::new(label_id);
642
643 let flag_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
644
645 let ray_flags_id = self.id_gen.next();
647 block.body.push(Instruction::composite_extract(
648 flag_type_id,
649 ray_flags_id,
650 desc_id,
651 &[0],
652 ));
653 let cull_mask_id = self.id_gen.next();
654 block.body.push(Instruction::composite_extract(
655 flag_type_id,
656 cull_mask_id,
657 desc_id,
658 &[1],
659 ));
660
661 let tmin_id = self.id_gen.next();
662 block.body.push(Instruction::composite_extract(
663 f32_type_id,
664 tmin_id,
665 desc_id,
666 &[2],
667 ));
668 let tmax_id = self.id_gen.next();
669 block.body.push(Instruction::composite_extract(
670 f32_type_id,
671 tmax_id,
672 desc_id,
673 &[3],
674 ));
675 block
676 .body
677 .push(Instruction::store(t_max_tracker_id, tmax_id, None));
678
679 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
680 size: crate::VectorSize::Tri,
681 scalar: crate::Scalar::F32,
682 });
683 let ray_origin_id = self.id_gen.next();
684 block.body.push(Instruction::composite_extract(
685 vector_type_id,
686 ray_origin_id,
687 desc_id,
688 &[4],
689 ));
690 let ray_dir_id = self.id_gen.next();
691 block.body.push(Instruction::composite_extract(
692 vector_type_id,
693 ray_dir_id,
694 desc_id,
695 &[5],
696 ));
697
698 let valid_id = self.ray_query_initialization_tracking.then(||{
699 let tmin_le_tmax_id = self.id_gen.next();
700 block.body.push(Instruction::binary(
704 spirv::Op::FOrdLessThanEqual,
705 bool_type_id,
706 tmin_le_tmax_id,
707 tmin_id,
708 tmax_id,
709 ));
710
711 let tmin_ge_zero_id = self.id_gen.next();
715 let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0));
716 block.body.push(Instruction::binary(
717 spirv::Op::FOrdGreaterThanEqual,
718 bool_type_id,
719 tmin_ge_zero_id,
720 tmin_id,
721 zero_id,
722 ));
723
724 let ray_origin_infinite_id = self.id_gen.next();
726 block.body.push(Instruction::unary(
727 spirv::Op::IsInf,
728 bool_vec3_type_id,
729 ray_origin_infinite_id,
730 ray_origin_id,
731 ));
732 let any_ray_origin_infinite_id = self.id_gen.next();
733 block.body.push(Instruction::unary(
734 spirv::Op::Any,
735 bool_type_id,
736 any_ray_origin_infinite_id,
737 ray_origin_infinite_id,
738 ));
739
740 let ray_origin_nan_id = self.id_gen.next();
741 block.body.push(Instruction::unary(
742 spirv::Op::IsNan,
743 bool_vec3_type_id,
744 ray_origin_nan_id,
745 ray_origin_id,
746 ));
747 let any_ray_origin_nan_id = self.id_gen.next();
748 block.body.push(Instruction::unary(
749 spirv::Op::Any,
750 bool_type_id,
751 any_ray_origin_nan_id,
752 ray_origin_nan_id,
753 ));
754
755 let ray_origin_not_finite_id = self.id_gen.next();
756 block.body.push(Instruction::binary(
757 spirv::Op::LogicalOr,
758 bool_type_id,
759 ray_origin_not_finite_id,
760 any_ray_origin_nan_id,
761 any_ray_origin_infinite_id,
762 ));
763
764 let all_ray_origin_finite_id = self.id_gen.next();
765 block.body.push(Instruction::unary(
766 spirv::Op::LogicalNot,
767 bool_type_id,
768 all_ray_origin_finite_id,
769 ray_origin_not_finite_id,
770 ));
771
772 let ray_dir_infinite_id = self.id_gen.next();
774 block.body.push(Instruction::unary(
775 spirv::Op::IsInf,
776 bool_vec3_type_id,
777 ray_dir_infinite_id,
778 ray_dir_id,
779 ));
780 let any_ray_dir_infinite_id = self.id_gen.next();
781 block.body.push(Instruction::unary(
782 spirv::Op::Any,
783 bool_type_id,
784 any_ray_dir_infinite_id,
785 ray_dir_infinite_id,
786 ));
787
788 let ray_dir_nan_id = self.id_gen.next();
789 block.body.push(Instruction::unary(
790 spirv::Op::IsNan,
791 bool_vec3_type_id,
792 ray_dir_nan_id,
793 ray_dir_id,
794 ));
795 let any_ray_dir_nan_id = self.id_gen.next();
796 block.body.push(Instruction::unary(
797 spirv::Op::Any,
798 bool_type_id,
799 any_ray_dir_nan_id,
800 ray_dir_nan_id,
801 ));
802
803 let ray_dir_not_finite_id = self.id_gen.next();
804 block.body.push(Instruction::binary(
805 spirv::Op::LogicalOr,
806 bool_type_id,
807 ray_dir_not_finite_id,
808 any_ray_dir_nan_id,
809 any_ray_dir_infinite_id,
810 ));
811
812 let all_ray_dir_finite_id = self.id_gen.next();
813 block.body.push(Instruction::unary(
814 spirv::Op::LogicalNot,
815 bool_type_id,
816 all_ray_dir_finite_id,
817 ray_dir_not_finite_id,
818 ));
819
820 fn write_less_than_2_true(
825 writer: &mut Writer,
826 block: &mut Block,
827 mut bools: Vec<spirv::Word>,
828 ) -> spirv::Word {
829 assert!(bools.len() > 1, "Must have multiple booleans!");
830 let bool_ty = writer.get_bool_type_id();
831 let mut each_two_true = Vec::new();
832 while let Some(last_bool) = bools.pop() {
833 for &bool in &bools {
834 let both_true_id = writer.write_logical_and(
835 block,
836 last_bool,
837 bool,
838 );
839 each_two_true.push(both_true_id);
840 }
841 }
842 let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true`");
843 for two_true in each_two_true {
844 let new_all_or_id = writer.id_gen.next();
845 block.body.push(Instruction::binary(
846 spirv::Op::LogicalOr,
847 bool_ty,
848 new_all_or_id,
849 all_or_id,
850 two_true,
851 ));
852 all_or_id = new_all_or_id;
853 }
854
855 let less_than_two_id = writer.id_gen.next();
856 block.body.push(Instruction::unary(
857 spirv::Op::LogicalNot,
858 bool_ty,
859 less_than_two_id,
860 all_or_id,
861 ));
862 less_than_two_id
863 }
864
865 let contains_skip_triangles = write_ray_flags_contains_flags(
868 self,
869 &mut block,
870 ray_flags_id,
871 crate::RayFlag::SKIP_TRIANGLES.bits(),
872 );
873 let contains_skip_aabbs = write_ray_flags_contains_flags(
874 self,
875 &mut block,
876 ray_flags_id,
877 crate::RayFlag::SKIP_AABBS.bits(),
878 );
879
880 let not_contain_skip_triangles_aabbs = write_less_than_2_true(
881 self,
882 &mut block,
883 vec![contains_skip_triangles, contains_skip_aabbs],
884 );
885
886 let contains_cull_back = write_ray_flags_contains_flags(
889 self,
890 &mut block,
891 ray_flags_id,
892 crate::RayFlag::CULL_BACK_FACING.bits(),
893 );
894 let contains_cull_front = write_ray_flags_contains_flags(
895 self,
896 &mut block,
897 ray_flags_id,
898 crate::RayFlag::CULL_FRONT_FACING.bits(),
899 );
900
901 let not_contain_skip_triangles_cull = write_less_than_2_true(
902 self,
903 &mut block,
904 vec![
905 contains_skip_triangles,
906 contains_cull_back,
907 contains_cull_front,
908 ],
909 );
910
911 let contains_opaque = write_ray_flags_contains_flags(
914 self,
915 &mut block,
916 ray_flags_id,
917 crate::RayFlag::FORCE_OPAQUE.bits(),
918 );
919 let contains_no_opaque = write_ray_flags_contains_flags(
920 self,
921 &mut block,
922 ray_flags_id,
923 crate::RayFlag::FORCE_NO_OPAQUE.bits(),
924 );
925 let contains_cull_opaque = write_ray_flags_contains_flags(
926 self,
927 &mut block,
928 ray_flags_id,
929 crate::RayFlag::CULL_OPAQUE.bits(),
930 );
931 let contains_cull_no_opaque = write_ray_flags_contains_flags(
932 self,
933 &mut block,
934 ray_flags_id,
935 crate::RayFlag::CULL_NO_OPAQUE.bits(),
936 );
937
938 let not_contain_multiple_opaque = write_less_than_2_true(
939 self,
940 &mut block,
941 vec![
942 contains_opaque,
943 contains_no_opaque,
944 contains_cull_opaque,
945 contains_cull_no_opaque,
946 ],
947 );
948
949 self.write_reduce_and(
951 &mut block,
952 vec![
953 tmin_le_tmax_id,
954 tmin_ge_zero_id,
955 all_ray_origin_finite_id,
956 all_ray_dir_finite_id,
957 not_contain_skip_triangles_aabbs,
958 not_contain_skip_triangles_cull,
959 not_contain_multiple_opaque,
960 ],
961 )
962 });
963
964 let merge_label_id = self.id_gen.next();
965 let merge_block = Block::new(merge_label_id);
966
967 let invalid_label_id = self.id_gen.next();
969 let mut invalid_block = Block::new(invalid_label_id);
970
971 let valid_label_id = self.id_gen.next();
972 let mut valid_block = Block::new(valid_label_id);
973
974 match valid_id {
975 Some(all_valid_id) => {
976 block.body.push(Instruction::selection_merge(
977 merge_label_id,
978 spirv::SelectionControl::NONE,
979 ));
980 function.consume(
981 block,
982 Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id),
983 );
984 }
985 None => {
986 function.consume(block, Instruction::branch(valid_label_id));
987 }
988 }
989
990 valid_block.body.push(Instruction::ray_query_initialize(
991 query_id,
992 acceleration_structure_id,
993 ray_flags_id,
994 cull_mask_id,
995 ray_origin_id,
996 tmin_id,
997 ray_dir_id,
998 tmax_id,
999 ));
1000
1001 let const_initialized =
1002 self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::INITIALIZED.bits()));
1003 valid_block
1004 .body
1005 .push(Instruction::store(init_tracker_id, const_initialized, None));
1006
1007 function.consume(valid_block, Instruction::branch(merge_label_id));
1008
1009 if self
1010 .flags
1011 .contains(super::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
1012 {
1013 self.write_debug_printf(
1014 &mut invalid_block,
1015 "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
1016 &[
1017 ray_flags_id,
1018 tmin_id,
1019 tmax_id,
1020 ray_origin_id,
1021 ray_dir_id,
1022 ],
1023 );
1024 }
1025
1026 function.consume(invalid_block, Instruction::branch(merge_label_id));
1027
1028 function.consume(merge_block, Instruction::return_void());
1029
1030 function.to_words(&mut self.logical_layout.function_definitions);
1031
1032 self.ray_query_functions
1033 .insert(LookupRayQueryFunction::Initialize, func_id);
1034 func_id
1035 }
1036
1037 fn write_ray_query_proceed(&mut self) -> spirv::Word {
1038 if let Some(&word) = self
1039 .ray_query_functions
1040 .get(&LookupRayQueryFunction::Proceed)
1041 {
1042 return word;
1043 }
1044
1045 let ray_query_type_id = self.get_ray_query_pointer_id();
1046
1047 let u32_ty = self.get_u32_type_id();
1048 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1049
1050 let bool_type_id = self.get_bool_type_id();
1051 let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
1052
1053 let (func_id, mut function, arg_ids) =
1054 self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], bool_type_id);
1055
1056 let query_id = arg_ids[0];
1057 let init_tracker_id = arg_ids[1];
1058
1059 let block_id = self.id_gen.next();
1060 let mut block = Block::new(block_id);
1061
1062 let proceeded_id = self.id_gen.next();
1064 let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
1065 block.body.push(Instruction::variable(
1066 bool_ptr_ty,
1067 proceeded_id,
1068 spirv::StorageClass::Function,
1069 Some(const_false),
1070 ));
1071
1072 let initialized_tracker_id = self.id_gen.next();
1073 block.body.push(Instruction::load(
1074 u32_ty,
1075 initialized_tracker_id,
1076 init_tracker_id,
1077 None,
1078 ));
1079
1080 let merge_id = self.id_gen.next();
1081 let mut merge_block = Block::new(merge_id);
1082
1083 let valid_block_id = self.id_gen.next();
1084 let mut valid_block = Block::new(valid_block_id);
1085
1086 let instruction = if self.ray_query_initialization_tracking {
1087 let is_initialized = write_ray_flags_contains_flags(
1088 self,
1089 &mut block,
1090 initialized_tracker_id,
1091 RayQueryPoint::INITIALIZED.bits(),
1092 );
1093
1094 block.body.push(Instruction::selection_merge(
1095 merge_id,
1096 spirv::SelectionControl::NONE,
1097 ));
1098
1099 Instruction::branch_conditional(is_initialized, valid_block_id, merge_id)
1100 } else {
1101 Instruction::branch(valid_block_id)
1102 };
1103
1104 function.consume(block, instruction);
1105
1106 let has_proceeded = self.id_gen.next();
1107 valid_block.body.push(Instruction::ray_query_proceed(
1108 bool_type_id,
1109 has_proceeded,
1110 query_id,
1111 ));
1112
1113 valid_block
1114 .body
1115 .push(Instruction::store(proceeded_id, has_proceeded, None));
1116
1117 let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
1118 (RayQueryPoint::PROCEED | RayQueryPoint::FINISHED_TRAVERSAL).bits(),
1119 ));
1120 let add_flag_continuing =
1121 self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::PROCEED.bits()));
1122
1123 let add_flags_id = self.id_gen.next();
1124 valid_block.body.push(Instruction::select(
1125 u32_ty,
1126 add_flags_id,
1127 has_proceeded,
1128 add_flag_continuing,
1129 add_flag_finished,
1130 ));
1131 let final_flags = self.id_gen.next();
1132 valid_block.body.push(Instruction::binary(
1133 spirv::Op::BitwiseOr,
1134 u32_ty,
1135 final_flags,
1136 initialized_tracker_id,
1137 add_flags_id,
1138 ));
1139 valid_block
1140 .body
1141 .push(Instruction::store(init_tracker_id, final_flags, None));
1142
1143 function.consume(valid_block, Instruction::branch(merge_id));
1144
1145 let loaded_proceeded_id = self.id_gen.next();
1146 merge_block.body.push(Instruction::load(
1147 bool_type_id,
1148 loaded_proceeded_id,
1149 proceeded_id,
1150 None,
1151 ));
1152
1153 function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));
1154
1155 function.to_words(&mut self.logical_layout.function_definitions);
1156
1157 self.ray_query_functions
1158 .insert(LookupRayQueryFunction::Proceed, func_id);
1159 func_id
1160 }
1161
1162 fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
1163 if let Some(&word) = self
1164 .ray_query_functions
1165 .get(&LookupRayQueryFunction::GenerateIntersection)
1166 {
1167 return word;
1168 }
1169
1170 let ray_query_type_id = self.get_ray_query_pointer_id();
1171
1172 let u32_ty = self.get_u32_type_id();
1173 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1174
1175 let f32_type_id = self.get_f32_type_id();
1176 let f32_ptr_type_id = self.get_pointer_type_id(f32_type_id, spirv::StorageClass::Function);
1177
1178 let bool_type_id = self.get_bool_type_id();
1179
1180 let (func_id, mut function, arg_ids) = self.write_function_signature(
1181 &[ray_query_type_id, u32_ptr_ty, f32_type_id, f32_ptr_type_id],
1182 self.void_type,
1183 );
1184
1185 let query_id = arg_ids[0];
1186 let init_tracker_id = arg_ids[1];
1187 let depth_id = arg_ids[2];
1188 let t_max_tracker_id = arg_ids[3];
1189
1190 let block_id = self.id_gen.next();
1191 let mut block = Block::new(block_id);
1192
1193 let current_t = self.id_gen.next();
1194 block.body.push(Instruction::variable(
1195 f32_ptr_type_id,
1196 current_t,
1197 spirv::StorageClass::Function,
1198 None,
1199 ));
1200
1201 let current_t = self.id_gen.next();
1202 block.body.push(Instruction::variable(
1203 f32_ptr_type_id,
1204 current_t,
1205 spirv::StorageClass::Function,
1206 None,
1207 ));
1208
1209 let valid_id = self.id_gen.next();
1210 let mut valid_block = Block::new(valid_id);
1211
1212 let final_label_id = self.id_gen.next();
1213 let final_block = Block::new(final_label_id);
1214
1215 let instruction = if self.ray_query_initialization_tracking {
1216 let initialized_tracker_id = self.id_gen.next();
1217 block.body.push(Instruction::load(
1218 u32_ty,
1219 initialized_tracker_id,
1220 init_tracker_id,
1221 None,
1222 ));
1223
1224 let proceeded_id = write_ray_flags_contains_flags(
1225 self,
1226 &mut block,
1227 initialized_tracker_id,
1228 RayQueryPoint::PROCEED.bits(),
1229 );
1230 let finished_proceed_id = write_ray_flags_contains_flags(
1231 self,
1232 &mut block,
1233 initialized_tracker_id,
1234 RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1235 );
1236
1237 let not_finished_id = self.id_gen.next();
1240 block.body.push(Instruction::unary(
1241 spirv::Op::LogicalNot,
1242 bool_type_id,
1243 not_finished_id,
1244 finished_proceed_id,
1245 ));
1246
1247 let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1248
1249 block.body.push(Instruction::selection_merge(
1250 final_label_id,
1251 spirv::SelectionControl::NONE,
1252 ));
1253
1254 Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1255 } else {
1256 Instruction::branch(valid_id)
1257 };
1258
1259 function.consume(block, instruction);
1260
1261 let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1262 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1263 ));
1264 let committed_intersection_id = self.get_constant_scalar(crate::Literal::U32(
1265 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
1266 ));
1267 let raw_kind_id = self.id_gen.next();
1268 valid_block
1269 .body
1270 .push(Instruction::ray_query_get_intersection(
1271 spirv::Op::RayQueryGetIntersectionTypeKHR,
1272 u32_ty,
1273 raw_kind_id,
1274 query_id,
1275 intersection_id,
1276 ));
1277
1278 let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
1279 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
1280 ));
1281 let intersection_aabb_id = self.id_gen.next();
1282 valid_block.body.push(Instruction::binary(
1283 spirv::Op::IEqual,
1284 bool_type_id,
1285 intersection_aabb_id,
1286 raw_kind_id,
1287 candidate_aabb_id,
1288 ));
1289
1290 let t_min_id = self.id_gen.next();
1295 valid_block.body.push(Instruction::ray_query_get_t_min(
1296 f32_type_id,
1297 t_min_id,
1298 query_id,
1299 ));
1300
1301 let committed_type_id = self.id_gen.next();
1318 valid_block
1319 .body
1320 .push(Instruction::ray_query_get_intersection(
1321 spirv::Op::RayQueryGetIntersectionTypeKHR,
1322 u32_ty,
1323 committed_type_id,
1324 query_id,
1325 committed_intersection_id,
1326 ));
1327
1328 let no_committed = self.id_gen.next();
1329 valid_block.body.push(Instruction::binary(
1330 spirv::Op::IEqual,
1331 bool_type_id,
1332 no_committed,
1333 committed_type_id,
1334 self.get_constant_scalar(crate::Literal::U32(
1335 spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR as _,
1336 )),
1337 ));
1338
1339 let next_valid_block_id = self.id_gen.next();
1340 let no_committed_block_id = self.id_gen.next();
1341 let mut no_committed_block = Block::new(no_committed_block_id);
1342 let committed_block_id = self.id_gen.next();
1343 let mut committed_block = Block::new(committed_block_id);
1344 valid_block.body.push(Instruction::selection_merge(
1345 next_valid_block_id,
1346 spirv::SelectionControl::NONE,
1347 ));
1348 function.consume(
1349 valid_block,
1350 Instruction::branch_conditional(
1351 no_committed,
1352 no_committed_block_id,
1353 committed_block_id,
1354 ),
1355 );
1356
1357 let t_max_id = self.id_gen.next();
1359 no_committed_block.body.push(Instruction::load(
1360 f32_type_id,
1361 t_max_id,
1362 t_max_tracker_id,
1363 None,
1364 ));
1365 no_committed_block
1366 .body
1367 .push(Instruction::store(current_t, t_max_id, None));
1368 function.consume(no_committed_block, Instruction::branch(next_valid_block_id));
1369
1370 let latest_t_id = self.id_gen.next();
1372 committed_block
1373 .body
1374 .push(Instruction::ray_query_get_intersection(
1375 spirv::Op::RayQueryGetIntersectionTKHR,
1376 f32_type_id,
1377 latest_t_id,
1378 query_id,
1379 intersection_id,
1380 ));
1381 committed_block
1382 .body
1383 .push(Instruction::store(current_t, latest_t_id, None));
1384 function.consume(committed_block, Instruction::branch(next_valid_block_id));
1385
1386 let mut valid_block = Block::new(next_valid_block_id);
1387
1388 let t_ge_t_min = self.id_gen.next();
1389 valid_block.body.push(Instruction::binary(
1390 spirv::Op::FOrdGreaterThanEqual,
1391 bool_type_id,
1392 t_ge_t_min,
1393 depth_id,
1394 t_min_id,
1395 ));
1396 let t_current = self.id_gen.next();
1397 valid_block
1398 .body
1399 .push(Instruction::load(f32_type_id, t_current, current_t, None));
1400 let t_le_t_current = self.id_gen.next();
1401 valid_block.body.push(Instruction::binary(
1402 spirv::Op::FOrdLessThanEqual,
1403 bool_type_id,
1404 t_le_t_current,
1405 depth_id,
1406 t_current,
1407 ));
1408
1409 let t_in_range = self.id_gen.next();
1410 valid_block.body.push(Instruction::binary(
1411 spirv::Op::LogicalAnd,
1412 bool_type_id,
1413 t_in_range,
1414 t_ge_t_min,
1415 t_le_t_current,
1416 ));
1417
1418 let call_valid_id = self.id_gen.next();
1419 valid_block.body.push(Instruction::binary(
1420 spirv::Op::LogicalAnd,
1421 bool_type_id,
1422 call_valid_id,
1423 t_in_range,
1424 intersection_aabb_id,
1425 ));
1426
1427 let generate_label_id = self.id_gen.next();
1428 let mut generate_block = Block::new(generate_label_id);
1429
1430 let merge_label_id = self.id_gen.next();
1431 let merge_block = Block::new(merge_label_id);
1432
1433 valid_block.body.push(Instruction::selection_merge(
1434 merge_label_id,
1435 spirv::SelectionControl::NONE,
1436 ));
1437 function.consume(
1438 valid_block,
1439 Instruction::branch_conditional(call_valid_id, generate_label_id, merge_label_id),
1440 );
1441
1442 generate_block
1443 .body
1444 .push(Instruction::ray_query_generate_intersection(
1445 query_id, depth_id,
1446 ));
1447
1448 function.consume(generate_block, Instruction::branch(merge_label_id));
1449 function.consume(merge_block, Instruction::branch(final_label_id));
1450
1451 function.consume(final_block, Instruction::return_void());
1452
1453 function.to_words(&mut self.logical_layout.function_definitions);
1454
1455 self.ray_query_functions
1456 .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
1457 func_id
1458 }
1459
1460 fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
1461 if let Some(&word) = self
1462 .ray_query_functions
1463 .get(&LookupRayQueryFunction::ConfirmIntersection)
1464 {
1465 return word;
1466 }
1467
1468 let ray_query_type_id = self.get_ray_query_pointer_id();
1469
1470 let u32_ty = self.get_u32_type_id();
1471 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1472
1473 let bool_type_id = self.get_bool_type_id();
1474
1475 let (func_id, mut function, arg_ids) =
1476 self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1477
1478 let query_id = arg_ids[0];
1479 let init_tracker_id = arg_ids[1];
1480
1481 let block_id = self.id_gen.next();
1482 let mut block = Block::new(block_id);
1483
1484 let valid_id = self.id_gen.next();
1485 let mut valid_block = Block::new(valid_id);
1486
1487 let final_label_id = self.id_gen.next();
1488 let final_block = Block::new(final_label_id);
1489
1490 let instruction = if self.ray_query_initialization_tracking {
1491 let initialized_tracker_id = self.id_gen.next();
1492 block.body.push(Instruction::load(
1493 u32_ty,
1494 initialized_tracker_id,
1495 init_tracker_id,
1496 None,
1497 ));
1498
1499 let proceeded_id = write_ray_flags_contains_flags(
1500 self,
1501 &mut block,
1502 initialized_tracker_id,
1503 RayQueryPoint::PROCEED.bits(),
1504 );
1505 let finished_proceed_id = write_ray_flags_contains_flags(
1506 self,
1507 &mut block,
1508 initialized_tracker_id,
1509 RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1510 );
1511 let not_finished_id = self.id_gen.next();
1513 block.body.push(Instruction::unary(
1514 spirv::Op::LogicalNot,
1515 bool_type_id,
1516 not_finished_id,
1517 finished_proceed_id,
1518 ));
1519
1520 let is_valid_id = self.write_logical_and(&mut block, not_finished_id, proceeded_id);
1521
1522 block.body.push(Instruction::selection_merge(
1523 final_label_id,
1524 spirv::SelectionControl::NONE,
1525 ));
1526
1527 Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1528 } else {
1529 Instruction::branch(valid_id)
1530 };
1531
1532 function.consume(block, instruction);
1533
1534 let intersection_id = self.get_constant_scalar(crate::Literal::U32(
1535 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
1536 ));
1537 let raw_kind_id = self.id_gen.next();
1538 valid_block
1539 .body
1540 .push(Instruction::ray_query_get_intersection(
1541 spirv::Op::RayQueryGetIntersectionTypeKHR,
1542 u32_ty,
1543 raw_kind_id,
1544 query_id,
1545 intersection_id,
1546 ));
1547
1548 let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
1549 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
1550 ));
1551 let intersection_tri_id = self.id_gen.next();
1552 valid_block.body.push(Instruction::binary(
1553 spirv::Op::IEqual,
1554 bool_type_id,
1555 intersection_tri_id,
1556 raw_kind_id,
1557 candidate_tri_id,
1558 ));
1559
1560 let generate_label_id = self.id_gen.next();
1561 let mut generate_block = Block::new(generate_label_id);
1562
1563 let merge_label_id = self.id_gen.next();
1564 let merge_block = Block::new(merge_label_id);
1565
1566 valid_block.body.push(Instruction::selection_merge(
1567 merge_label_id,
1568 spirv::SelectionControl::NONE,
1569 ));
1570 function.consume(
1571 valid_block,
1572 Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1573 );
1574
1575 generate_block
1576 .body
1577 .push(Instruction::ray_query_confirm_intersection(query_id));
1578
1579 function.consume(generate_block, Instruction::branch(merge_label_id));
1580 function.consume(merge_block, Instruction::branch(final_label_id));
1581
1582 function.consume(final_block, Instruction::return_void());
1583
1584 self.ray_query_functions
1585 .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);
1586
1587 function.to_words(&mut self.logical_layout.function_definitions);
1588
1589 func_id
1590 }
1591
1592 fn write_ray_query_get_vertex_positions(
1593 &mut self,
1594 is_committed: bool,
1595 ir_module: &crate::Module,
1596 ) -> spirv::Word {
1597 if let Some(&word) =
1598 self.ray_query_functions
1599 .get(&LookupRayQueryFunction::GetVertexPositions {
1600 committed: is_committed,
1601 })
1602 {
1603 return word;
1604 }
1605
1606 let (committed_ty, committed_tri_ty) = if is_committed {
1607 (
1608 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
1609 spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
1610 as u32,
1611 )
1612 } else {
1613 (
1614 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
1615 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
1616 as u32,
1617 )
1618 };
1619
1620 let ray_query_type_id = self.get_ray_query_pointer_id();
1621
1622 let u32_ty = self.get_u32_type_id();
1623 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1624
1625 let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1626 *ir_module
1627 .special_types
1628 .ray_vertex_return
1629 .as_ref()
1630 .expect("must be generated when reading in get vertex position"),
1631 );
1632 let ptr_return_ty =
1633 self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);
1634
1635 let bool_type_id = self.get_bool_type_id();
1636
1637 let (func_id, mut function, arg_ids) = self.write_function_signature(
1638 &[ray_query_type_id, u32_ptr_ty],
1639 rq_get_vertex_positions_ty_id,
1640 );
1641
1642 let query_id = arg_ids[0];
1643 let init_tracker_id = arg_ids[1];
1644
1645 let block_id = self.id_gen.next();
1646 let mut block = Block::new(block_id);
1647
1648 let return_id = self.id_gen.next();
1649 block.body.push(Instruction::variable(
1650 ptr_return_ty,
1651 return_id,
1652 spirv::StorageClass::Function,
1653 Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
1654 ));
1655
1656 let valid_id = self.id_gen.next();
1657 let mut valid_block = Block::new(valid_id);
1658
1659 let final_label_id = self.id_gen.next();
1660 let mut final_block = Block::new(final_label_id);
1661
1662 let instruction = if self.ray_query_initialization_tracking {
1663 let initialized_tracker_id = self.id_gen.next();
1664 block.body.push(Instruction::load(
1665 u32_ty,
1666 initialized_tracker_id,
1667 init_tracker_id,
1668 None,
1669 ));
1670
1671 let proceeded_id = write_ray_flags_contains_flags(
1672 self,
1673 &mut block,
1674 initialized_tracker_id,
1675 RayQueryPoint::PROCEED.bits(),
1676 );
1677 let finished_proceed_id = write_ray_flags_contains_flags(
1678 self,
1679 &mut block,
1680 initialized_tracker_id,
1681 RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1682 );
1683
1684 let correct_finish_id = if is_committed {
1685 finished_proceed_id
1686 } else {
1687 let not_finished_id = self.id_gen.next();
1688 block.body.push(Instruction::unary(
1689 spirv::Op::LogicalNot,
1690 bool_type_id,
1691 not_finished_id,
1692 finished_proceed_id,
1693 ));
1694 not_finished_id
1695 };
1696
1697 let is_valid_id = self.write_logical_and(&mut block, correct_finish_id, proceeded_id);
1698 block.body.push(Instruction::selection_merge(
1699 final_label_id,
1700 spirv::SelectionControl::NONE,
1701 ));
1702 Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
1703 } else {
1704 Instruction::branch(valid_id)
1705 };
1706
1707 function.consume(block, instruction);
1708
1709 let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
1710 let raw_kind_id = self.id_gen.next();
1711 valid_block
1712 .body
1713 .push(Instruction::ray_query_get_intersection(
1714 spirv::Op::RayQueryGetIntersectionTypeKHR,
1715 u32_ty,
1716 raw_kind_id,
1717 query_id,
1718 intersection_id,
1719 ));
1720
1721 let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
1722 let intersection_tri_id = self.id_gen.next();
1723 valid_block.body.push(Instruction::binary(
1724 spirv::Op::IEqual,
1725 bool_type_id,
1726 intersection_tri_id,
1727 raw_kind_id,
1728 candidate_tri_id,
1729 ));
1730
1731 let generate_label_id = self.id_gen.next();
1732 let mut vertex_return_block = Block::new(generate_label_id);
1733
1734 let merge_label_id = self.id_gen.next();
1735 let merge_block = Block::new(merge_label_id);
1736
1737 valid_block.body.push(Instruction::selection_merge(
1738 merge_label_id,
1739 spirv::SelectionControl::NONE,
1740 ));
1741 function.consume(
1742 valid_block,
1743 Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
1744 );
1745
1746 let vertices_id = self.id_gen.next();
1747 vertex_return_block
1748 .body
1749 .push(Instruction::ray_query_return_vertex_position(
1750 rq_get_vertex_positions_ty_id,
1751 vertices_id,
1752 query_id,
1753 intersection_id,
1754 ));
1755 vertex_return_block
1756 .body
1757 .push(Instruction::store(return_id, vertices_id, None));
1758
1759 function.consume(vertex_return_block, Instruction::branch(merge_label_id));
1760 function.consume(merge_block, Instruction::branch(final_label_id));
1761
1762 let loaded_pos_id = self.id_gen.next();
1763 final_block.body.push(Instruction::load(
1764 rq_get_vertex_positions_ty_id,
1765 loaded_pos_id,
1766 return_id,
1767 None,
1768 ));
1769
1770 function.consume(final_block, Instruction::return_value(loaded_pos_id));
1771
1772 self.ray_query_functions.insert(
1773 LookupRayQueryFunction::GetVertexPositions {
1774 committed: is_committed,
1775 },
1776 func_id,
1777 );
1778
1779 function.to_words(&mut self.logical_layout.function_definitions);
1780
1781 func_id
1782 }
1783
1784 fn write_ray_query_terminate(&mut self) -> spirv::Word {
1785 if let Some(&word) = self
1786 .ray_query_functions
1787 .get(&LookupRayQueryFunction::Terminate)
1788 {
1789 return word;
1790 }
1791
1792 let ray_query_type_id = self.get_ray_query_pointer_id();
1793
1794 let u32_ty = self.get_u32_type_id();
1795 let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
1796
1797 let bool_type_id = self.get_bool_type_id();
1798
1799 let (func_id, mut function, arg_ids) =
1800 self.write_function_signature(&[ray_query_type_id, u32_ptr_ty], self.void_type);
1801
1802 let query_id = arg_ids[0];
1803 let init_tracker_id = arg_ids[1];
1804
1805 let block_id = self.id_gen.next();
1806 let mut block = Block::new(block_id);
1807
1808 let initialized_tracker_id = self.id_gen.next();
1809 block.body.push(Instruction::load(
1810 u32_ty,
1811 initialized_tracker_id,
1812 init_tracker_id,
1813 None,
1814 ));
1815
1816 let merge_id = self.id_gen.next();
1817 let merge_block = Block::new(merge_id);
1818
1819 let valid_block_id = self.id_gen.next();
1820 let mut valid_block = Block::new(valid_block_id);
1821
1822 let instruction = if self.ray_query_initialization_tracking {
1823 let has_proceeded = write_ray_flags_contains_flags(
1824 self,
1825 &mut block,
1826 initialized_tracker_id,
1827 RayQueryPoint::PROCEED.bits(),
1828 );
1829
1830 let finished_proceed_id = write_ray_flags_contains_flags(
1831 self,
1832 &mut block,
1833 initialized_tracker_id,
1834 RayQueryPoint::FINISHED_TRAVERSAL.bits(),
1835 );
1836
1837 let not_finished_id = self.id_gen.next();
1838 block.body.push(Instruction::unary(
1839 spirv::Op::LogicalNot,
1840 bool_type_id,
1841 not_finished_id,
1842 finished_proceed_id,
1843 ));
1844
1845 let valid_call = self.write_logical_and(&mut block, not_finished_id, has_proceeded);
1846
1847 block.body.push(Instruction::selection_merge(
1848 merge_id,
1849 spirv::SelectionControl::NONE,
1850 ));
1851
1852 Instruction::branch_conditional(valid_call, valid_block_id, merge_id)
1853 } else {
1854 Instruction::branch(valid_block_id)
1855 };
1856
1857 function.consume(block, instruction);
1858
1859 valid_block
1860 .body
1861 .push(Instruction::ray_query_terminate(query_id));
1862
1863 function.consume(valid_block, Instruction::branch(merge_id));
1864
1865 function.consume(merge_block, Instruction::return_void());
1866
1867 function.to_words(&mut self.logical_layout.function_definitions);
1868
1869 self.ray_query_functions
1870 .insert(LookupRayQueryFunction::Proceed, func_id);
1871 func_id
1872 }
1873}
1874
1875impl BlockContext<'_> {
1876 pub(super) fn write_ray_query_function(
1877 &mut self,
1878 query: Handle<crate::Expression>,
1879 function: &crate::RayQueryFunction,
1880 block: &mut Block,
1881 ) {
1882 let query_id = self.cached[query];
1883 let tracker_ids = *self
1884 .ray_query_tracker_expr
1885 .get(&query)
1886 .expect("not a cached ray query");
1887
1888 match *function {
1889 crate::RayQueryFunction::Initialize {
1890 acceleration_structure,
1891 descriptor,
1892 } => {
1893 let desc_id = self.cached[descriptor];
1894 let acc_struct_id = self.get_handle_id(acceleration_structure);
1895
1896 let func = self.writer.write_ray_query_initialize(self.ir_module);
1897
1898 let func_id = self.gen_id();
1899 block.body.push(Instruction::function_call(
1900 self.writer.void_type,
1901 func_id,
1902 func,
1903 &[
1904 query_id,
1905 acc_struct_id,
1906 desc_id,
1907 tracker_ids.initialized_tracker,
1908 tracker_ids.t_max_tracker,
1909 ],
1910 ));
1911 }
1912 crate::RayQueryFunction::Proceed { result } => {
1913 let id = self.gen_id();
1914 self.cached[result] = id;
1915
1916 let bool_ty = self.writer.get_bool_type_id();
1917
1918 let func_id = self.writer.write_ray_query_proceed();
1919 block.body.push(Instruction::function_call(
1920 bool_ty,
1921 id,
1922 func_id,
1923 &[query_id, tracker_ids.initialized_tracker],
1924 ));
1925 }
1926 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1927 let hit_id = self.cached[hit_t];
1928
1929 let func_id = self.writer.write_ray_query_generate_intersection();
1930
1931 let func_call_id = self.gen_id();
1932 block.body.push(Instruction::function_call(
1933 self.writer.void_type,
1934 func_call_id,
1935 func_id,
1936 &[
1937 query_id,
1938 tracker_ids.initialized_tracker,
1939 hit_id,
1940 tracker_ids.t_max_tracker,
1941 ],
1942 ));
1943 }
1944 crate::RayQueryFunction::ConfirmIntersection => {
1945 let func_id = self.writer.write_ray_query_confirm_intersection();
1946
1947 let func_call_id = self.gen_id();
1948 block.body.push(Instruction::function_call(
1949 self.writer.void_type,
1950 func_call_id,
1951 func_id,
1952 &[query_id, tracker_ids.initialized_tracker],
1953 ));
1954 }
1955 crate::RayQueryFunction::Terminate => {
1956 let id = self.gen_id();
1957
1958 let func_id = self.writer.write_ray_query_terminate();
1959 block.body.push(Instruction::function_call(
1960 self.writer.void_type,
1961 id,
1962 func_id,
1963 &[query_id, tracker_ids.initialized_tracker],
1964 ));
1965 }
1966 }
1967 }
1968
1969 pub(super) fn write_ray_query_return_vertex_position(
1970 &mut self,
1971 query: Handle<crate::Expression>,
1972 block: &mut Block,
1973 is_committed: bool,
1974 ) -> spirv::Word {
1975 let fn_id = self
1976 .writer
1977 .write_ray_query_get_vertex_positions(is_committed, self.ir_module);
1978
1979 let query_id = self.cached[query];
1980 let tracker_id = *self
1981 .ray_query_tracker_expr
1982 .get(&query)
1983 .expect("not a cached ray query");
1984
1985 let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
1986 *self
1987 .ir_module
1988 .special_types
1989 .ray_vertex_return
1990 .as_ref()
1991 .expect("must be generated when reading in get vertex position"),
1992 );
1993
1994 let func_call_id = self.gen_id();
1995 block.body.push(Instruction::function_call(
1996 rq_get_vertex_positions_ty_id,
1997 func_call_id,
1998 fn_id,
1999 &[query_id, tracker_id.initialized_tracker],
2000 ));
2001 func_call_id
2002 }
2003}