1use alloc::vec;
6
7use super::{
8 Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType,
9 Writer,
10};
11use crate::arena::Handle;
12
13impl Writer {
14 pub(super) fn write_ray_query_get_intersection_function(
15 &mut self,
16 is_committed: bool,
17 ir_module: &crate::Module,
18 ) -> spirv::Word {
19 if is_committed {
20 if let Some(func_id) = self.ray_get_committed_intersection_function {
21 return func_id;
22 }
23 } else if let Some(func_id) = self.ray_get_candidate_intersection_function {
24 return func_id;
25 };
26 let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
27 let intersection_type_id = self.get_handle_type_id(ray_intersection);
28 let intersection_pointer_type_id =
29 self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
30
31 let flag_type_id = self.get_u32_type_id();
32 let flag_pointer_type_id =
33 self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
34
35 let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
36 columns: crate::VectorSize::Quad,
37 rows: crate::VectorSize::Tri,
38 scalar: crate::Scalar::F32,
39 });
40 let transform_pointer_type_id =
41 self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
42
43 let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
44 size: crate::VectorSize::Bi,
45 scalar: crate::Scalar::F32,
46 });
47 let barycentrics_pointer_type_id =
48 self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
49
50 let bool_type_id = self.get_bool_type_id();
51 let bool_pointer_type_id =
52 self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
53
54 let scalar_type_id = self.get_f32_type_id();
55 let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
56
57 let argument_type_id = self.get_ray_query_pointer_id();
58
59 let func_ty = self.get_function_type(LookupFunctionType {
60 parameter_type_ids: vec![argument_type_id],
61 return_type_id: intersection_type_id,
62 });
63
64 let mut function = Function::default();
65 let func_id = self.id_gen.next();
66 function.signature = Some(Instruction::function(
67 intersection_type_id,
68 func_id,
69 spirv::FunctionControl::empty(),
70 func_ty,
71 ));
72 let blank_intersection = self.get_constant_null(intersection_type_id);
73 let query_id = self.id_gen.next();
74 let instruction = Instruction::function_parameter(argument_type_id, query_id);
75 function.parameters.push(FunctionArgument {
76 instruction,
77 handle_id: 0,
78 });
79
80 let label_id = self.id_gen.next();
81 let mut block = Block::new(label_id);
82
83 let blank_intersection_id = self.id_gen.next();
84 block.body.push(Instruction::variable(
85 intersection_pointer_type_id,
86 blank_intersection_id,
87 spirv::StorageClass::Function,
88 Some(blank_intersection),
89 ));
90
91 let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
92 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
93 } else {
94 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
95 } as _));
96 let raw_kind_id = self.id_gen.next();
97 block.body.push(Instruction::ray_query_get_intersection(
98 spirv::Op::RayQueryGetIntersectionTypeKHR,
99 flag_type_id,
100 raw_kind_id,
101 query_id,
102 intersection_id,
103 ));
104 let kind_id = if is_committed {
105 raw_kind_id
107 } else {
108 let condition_id = self.id_gen.next();
110 let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
111 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
112 as _,
113 ));
114 block.body.push(Instruction::binary(
115 spirv::Op::IEqual,
116 self.get_bool_type_id(),
117 condition_id,
118 raw_kind_id,
119 committed_triangle_kind_id,
120 ));
121 let kind_id = self.id_gen.next();
122 block.body.push(Instruction::select(
123 flag_type_id,
124 kind_id,
125 condition_id,
126 self.get_constant_scalar(crate::Literal::U32(
127 crate::RayQueryIntersection::Triangle as _,
128 )),
129 self.get_constant_scalar(crate::Literal::U32(
130 crate::RayQueryIntersection::Aabb as _,
131 )),
132 ));
133 kind_id
134 };
135 let idx_id = self.get_index_constant(0);
136 let access_idx = self.id_gen.next();
137 block.body.push(Instruction::access_chain(
138 flag_pointer_type_id,
139 access_idx,
140 blank_intersection_id,
141 &[idx_id],
142 ));
143 block
144 .body
145 .push(Instruction::store(access_idx, kind_id, None));
146
147 let not_none_comp_id = self.id_gen.next();
148 let none_id =
149 self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
150 block.body.push(Instruction::binary(
151 spirv::Op::INotEqual,
152 self.get_bool_type_id(),
153 not_none_comp_id,
154 kind_id,
155 none_id,
156 ));
157
158 let not_none_label_id = self.id_gen.next();
159 let mut not_none_block = Block::new(not_none_label_id);
160
161 let final_label_id = self.id_gen.next();
162 let mut final_block = Block::new(final_label_id);
163
164 block.body.push(Instruction::selection_merge(
165 final_label_id,
166 spirv::SelectionControl::NONE,
167 ));
168 function.consume(
169 block,
170 Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id),
171 );
172
173 let instance_custom_index_id = self.id_gen.next();
174 not_none_block
175 .body
176 .push(Instruction::ray_query_get_intersection(
177 spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
178 flag_type_id,
179 instance_custom_index_id,
180 query_id,
181 intersection_id,
182 ));
183 let instance_id = self.id_gen.next();
184 not_none_block
185 .body
186 .push(Instruction::ray_query_get_intersection(
187 spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
188 flag_type_id,
189 instance_id,
190 query_id,
191 intersection_id,
192 ));
193 let sbt_record_offset_id = self.id_gen.next();
194 not_none_block
195 .body
196 .push(Instruction::ray_query_get_intersection(
197 spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
198 flag_type_id,
199 sbt_record_offset_id,
200 query_id,
201 intersection_id,
202 ));
203 let geometry_index_id = self.id_gen.next();
204 not_none_block
205 .body
206 .push(Instruction::ray_query_get_intersection(
207 spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
208 flag_type_id,
209 geometry_index_id,
210 query_id,
211 intersection_id,
212 ));
213 let primitive_index_id = self.id_gen.next();
214 not_none_block
215 .body
216 .push(Instruction::ray_query_get_intersection(
217 spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
218 flag_type_id,
219 primitive_index_id,
220 query_id,
221 intersection_id,
222 ));
223
224 let object_to_world_id = self.id_gen.next();
228 not_none_block
229 .body
230 .push(Instruction::ray_query_get_intersection(
231 spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
232 transform_type_id,
233 object_to_world_id,
234 query_id,
235 intersection_id,
236 ));
237 let world_to_object_id = self.id_gen.next();
238 not_none_block
239 .body
240 .push(Instruction::ray_query_get_intersection(
241 spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
242 transform_type_id,
243 world_to_object_id,
244 query_id,
245 intersection_id,
246 ));
247
248 let idx_id = self.get_index_constant(2);
250 let access_idx = self.id_gen.next();
251 not_none_block.body.push(Instruction::access_chain(
252 flag_pointer_type_id,
253 access_idx,
254 blank_intersection_id,
255 &[idx_id],
256 ));
257 not_none_block.body.push(Instruction::store(
258 access_idx,
259 instance_custom_index_id,
260 None,
261 ));
262
263 let idx_id = self.get_index_constant(3);
265 let access_idx = self.id_gen.next();
266 not_none_block.body.push(Instruction::access_chain(
267 flag_pointer_type_id,
268 access_idx,
269 blank_intersection_id,
270 &[idx_id],
271 ));
272 not_none_block
273 .body
274 .push(Instruction::store(access_idx, instance_id, None));
275
276 let idx_id = self.get_index_constant(4);
277 let access_idx = self.id_gen.next();
278 not_none_block.body.push(Instruction::access_chain(
279 flag_pointer_type_id,
280 access_idx,
281 blank_intersection_id,
282 &[idx_id],
283 ));
284 not_none_block
285 .body
286 .push(Instruction::store(access_idx, sbt_record_offset_id, None));
287
288 let idx_id = self.get_index_constant(5);
289 let access_idx = self.id_gen.next();
290 not_none_block.body.push(Instruction::access_chain(
291 flag_pointer_type_id,
292 access_idx,
293 blank_intersection_id,
294 &[idx_id],
295 ));
296 not_none_block
297 .body
298 .push(Instruction::store(access_idx, geometry_index_id, None));
299
300 let idx_id = self.get_index_constant(6);
301 let access_idx = self.id_gen.next();
302 not_none_block.body.push(Instruction::access_chain(
303 flag_pointer_type_id,
304 access_idx,
305 blank_intersection_id,
306 &[idx_id],
307 ));
308 not_none_block
309 .body
310 .push(Instruction::store(access_idx, primitive_index_id, None));
311
312 let idx_id = self.get_index_constant(9);
313 let access_idx = self.id_gen.next();
314 not_none_block.body.push(Instruction::access_chain(
315 transform_pointer_type_id,
316 access_idx,
317 blank_intersection_id,
318 &[idx_id],
319 ));
320 not_none_block
321 .body
322 .push(Instruction::store(access_idx, object_to_world_id, None));
323
324 let idx_id = self.get_index_constant(10);
325 let access_idx = self.id_gen.next();
326 not_none_block.body.push(Instruction::access_chain(
327 transform_pointer_type_id,
328 access_idx,
329 blank_intersection_id,
330 &[idx_id],
331 ));
332 not_none_block
333 .body
334 .push(Instruction::store(access_idx, world_to_object_id, None));
335
336 let tri_comp_id = self.id_gen.next();
337 let tri_id = self.get_constant_scalar(crate::Literal::U32(
338 crate::RayQueryIntersection::Triangle as _,
339 ));
340 not_none_block.body.push(Instruction::binary(
341 spirv::Op::IEqual,
342 self.get_bool_type_id(),
343 tri_comp_id,
344 kind_id,
345 tri_id,
346 ));
347
348 let tri_label_id = self.id_gen.next();
349 let mut tri_block = Block::new(tri_label_id);
350
351 let merge_label_id = self.id_gen.next();
352 let merge_block = Block::new(merge_label_id);
353 {
355 let block = if is_committed {
356 &mut not_none_block
357 } else {
358 &mut tri_block
359 };
360 let t_id = self.id_gen.next();
361 block.body.push(Instruction::ray_query_get_intersection(
362 spirv::Op::RayQueryGetIntersectionTKHR,
363 scalar_type_id,
364 t_id,
365 query_id,
366 intersection_id,
367 ));
368 let idx_id = self.get_index_constant(1);
369 let access_idx = self.id_gen.next();
370 block.body.push(Instruction::access_chain(
371 float_pointer_type_id,
372 access_idx,
373 blank_intersection_id,
374 &[idx_id],
375 ));
376 block.body.push(Instruction::store(access_idx, t_id, None));
377 }
378 not_none_block.body.push(Instruction::selection_merge(
379 merge_label_id,
380 spirv::SelectionControl::NONE,
381 ));
382 function.consume(
383 not_none_block,
384 Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
385 );
386
387 let barycentrics_id = self.id_gen.next();
388 tri_block.body.push(Instruction::ray_query_get_intersection(
389 spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
390 barycentrics_type_id,
391 barycentrics_id,
392 query_id,
393 intersection_id,
394 ));
395
396 let front_face_id = self.id_gen.next();
397 tri_block.body.push(Instruction::ray_query_get_intersection(
398 spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
399 bool_type_id,
400 front_face_id,
401 query_id,
402 intersection_id,
403 ));
404
405 let idx_id = self.get_index_constant(7);
406 let access_idx = self.id_gen.next();
407 tri_block.body.push(Instruction::access_chain(
408 barycentrics_pointer_type_id,
409 access_idx,
410 blank_intersection_id,
411 &[idx_id],
412 ));
413 tri_block
414 .body
415 .push(Instruction::store(access_idx, barycentrics_id, None));
416
417 let idx_id = self.get_index_constant(8);
418 let access_idx = self.id_gen.next();
419 tri_block.body.push(Instruction::access_chain(
420 bool_pointer_type_id,
421 access_idx,
422 blank_intersection_id,
423 &[idx_id],
424 ));
425 tri_block
426 .body
427 .push(Instruction::store(access_idx, front_face_id, None));
428 function.consume(tri_block, Instruction::branch(merge_label_id));
429 function.consume(merge_block, Instruction::branch(final_label_id));
430
431 let loaded_blank_intersection_id = self.id_gen.next();
432 final_block.body.push(Instruction::load(
433 intersection_type_id,
434 loaded_blank_intersection_id,
435 blank_intersection_id,
436 None,
437 ));
438 function.consume(
439 final_block,
440 Instruction::return_value(loaded_blank_intersection_id),
441 );
442
443 function.to_words(&mut self.logical_layout.function_definitions);
444 if is_committed {
445 self.ray_get_committed_intersection_function = Some(func_id);
446 } else {
447 self.ray_get_candidate_intersection_function = Some(func_id);
448 }
449 func_id
450 }
451}
452
453impl BlockContext<'_> {
454 pub(super) fn write_ray_query_function(
455 &mut self,
456 query: Handle<crate::Expression>,
457 function: &crate::RayQueryFunction,
458 block: &mut Block,
459 ) {
460 let query_id = self.cached[query];
461 match *function {
462 crate::RayQueryFunction::Initialize {
463 acceleration_structure,
464 descriptor,
465 } => {
466 let desc_id = self.cached[descriptor];
468 let acc_struct_id = self.get_handle_id(acceleration_structure);
469
470 let flag_type_id =
471 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
472 let ray_flags_id = self.gen_id();
473 block.body.push(Instruction::composite_extract(
474 flag_type_id,
475 ray_flags_id,
476 desc_id,
477 &[0],
478 ));
479 let cull_mask_id = self.gen_id();
480 block.body.push(Instruction::composite_extract(
481 flag_type_id,
482 cull_mask_id,
483 desc_id,
484 &[1],
485 ));
486
487 let scalar_type_id =
488 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32));
489 let tmin_id = self.gen_id();
490 block.body.push(Instruction::composite_extract(
491 scalar_type_id,
492 tmin_id,
493 desc_id,
494 &[2],
495 ));
496 let tmax_id = self.gen_id();
497 block.body.push(Instruction::composite_extract(
498 scalar_type_id,
499 tmax_id,
500 desc_id,
501 &[3],
502 ));
503
504 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
505 size: crate::VectorSize::Tri,
506 scalar: crate::Scalar::F32,
507 });
508 let ray_origin_id = self.gen_id();
509 block.body.push(Instruction::composite_extract(
510 vector_type_id,
511 ray_origin_id,
512 desc_id,
513 &[4],
514 ));
515 let ray_dir_id = self.gen_id();
516 block.body.push(Instruction::composite_extract(
517 vector_type_id,
518 ray_dir_id,
519 desc_id,
520 &[5],
521 ));
522
523 block.body.push(Instruction::ray_query_initialize(
524 query_id,
525 acc_struct_id,
526 ray_flags_id,
527 cull_mask_id,
528 ray_origin_id,
529 tmin_id,
530 ray_dir_id,
531 tmax_id,
532 ));
533 }
534 crate::RayQueryFunction::Proceed { result } => {
535 let id = self.gen_id();
536 self.cached[result] = id;
537 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
538
539 block
540 .body
541 .push(Instruction::ray_query_proceed(result_type_id, id, query_id));
542 }
543 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
544 let hit_id = self.cached[hit_t];
545 block
546 .body
547 .push(Instruction::ray_query_generate_intersection(
548 query_id, hit_id,
549 ));
550 }
551 crate::RayQueryFunction::ConfirmIntersection => {
552 block
553 .body
554 .push(Instruction::ray_query_confirm_intersection(query_id));
555 }
556 crate::RayQueryFunction::Terminate => {}
557 }
558 }
559
560 pub(super) fn write_ray_query_return_vertex_position(
561 &mut self,
562 query: Handle<crate::Expression>,
563 block: &mut Block,
564 is_committed: bool,
565 ) -> spirv::Word {
566 let query_id = self.cached[query];
567 let id = self.gen_id();
568 let ray_vertex_return_ty = self
569 .ir_module
570 .special_types
571 .ray_vertex_return
572 .expect("type should have been populated");
573 let ray_vertex_return_ty_id = self.writer.get_handle_type_id(ray_vertex_return_ty);
574 let intersection_id =
575 self.writer
576 .get_constant_scalar(crate::Literal::U32(if is_committed {
577 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
578 } else {
579 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
580 } as _));
581 block
582 .body
583 .push(Instruction::ray_query_return_vertex_position(
584 ray_vertex_return_ty_id,
585 id,
586 query_id,
587 intersection_id,
588 ));
589 id
590 }
591}