1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::fmt::Write;
8
9use crate::{
10 back::{hlsl::BackendResult, Baked, Level},
11 Handle,
12};
13use crate::{RayQueryIntersection, TypeInner};
14
15impl<W: Write> super::Writer<'_, W> {
16 fn write_not_finite(&mut self, expr: &str) -> BackendResult {
18 self.write_contains_flags(&format!("asuint({expr})"), 0x7f800000)
19 }
20
21 fn write_nan(&mut self, expr: &str) -> BackendResult {
22 write!(self.out, "(")?;
23 self.write_not_finite(expr)?;
24 write!(self.out, " && ((asuint({expr}) & 0x7fffff) != 0))")?;
25 Ok(())
26 }
27
28 fn write_contains_flags(&mut self, expr: &str, flags: u32) -> BackendResult {
29 write!(self.out, "(({expr} & {flags}) == {flags})")?;
30 Ok(())
31 }
32
33 pub(super) fn write_ray_desc_from_ray_desc_constructor_function(
35 &mut self,
36 module: &crate::Module,
37 ) -> BackendResult {
38 write!(self.out, "RayDesc RayDescFromRayDesc_(")?;
39 self.write_type(module, module.special_types.ray_desc.unwrap())?;
40 writeln!(self.out, " arg0) {{")?;
41 writeln!(self.out, " RayDesc ret = (RayDesc)0;")?;
42 writeln!(self.out, " ret.Origin = arg0.origin;")?;
43 writeln!(self.out, " ret.TMin = arg0.tmin;")?;
44 writeln!(self.out, " ret.Direction = arg0.dir;")?;
45 writeln!(self.out, " ret.TMax = arg0.tmax;")?;
46 writeln!(self.out, " return ret;")?;
47 writeln!(self.out, "}}")?;
48 writeln!(self.out)?;
49 Ok(())
50 }
51 pub(super) fn write_committed_intersection_function(
52 &mut self,
53 module: &crate::Module,
54 ) -> BackendResult {
55 self.write_type(module, module.special_types.ray_intersection.unwrap())?;
56 write!(self.out, " GetCommittedIntersection(")?;
57 self.write_value_type(
58 module,
59 &TypeInner::RayQuery {
60 vertex_return: false,
61 },
62 )?;
63 write!(self.out, " rq, ")?;
64 self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?;
65 writeln!(self.out, " rq_tracker) {{")?;
66 write!(self.out, " ")?;
67 self.write_type(module, module.special_types.ray_intersection.unwrap())?;
68 write!(self.out, " ret = (")?;
69 self.write_type(module, module.special_types.ray_intersection.unwrap())?;
70 writeln!(self.out, ")0;")?;
71 let mut extra_level = Level(0);
72 if self.options.ray_query_initialization_tracking {
73 write!(self.out, " if (")?;
77 self.write_contains_flags(
78 "rq_tracker",
79 crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
80 )?;
81 writeln!(self.out, ") {{")?;
82 extra_level = extra_level.next();
83 }
84 writeln!(
85 self.out,
86 " {extra_level}ret.kind = rq.CommittedStatus();"
87 )?;
88 writeln!(
89 self.out,
90 " {extra_level}if( rq.CommittedStatus() == COMMITTED_NOTHING) {{}} else {{"
91 )?;
92 writeln!(self.out, " {extra_level}ret.t = rq.CommittedRayT();")?;
93 writeln!(
94 self.out,
95 " {extra_level}ret.instance_custom_data = rq.CommittedInstanceID();"
96 )?;
97 writeln!(
98 self.out,
99 " {extra_level}ret.instance_index = rq.CommittedInstanceIndex();"
100 )?;
101 writeln!(
102 self.out,
103 " {extra_level}ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex();"
104 )?;
105 writeln!(
106 self.out,
107 " {extra_level}ret.geometry_index = rq.CommittedGeometryIndex();"
108 )?;
109 writeln!(
110 self.out,
111 " {extra_level}ret.primitive_index = rq.CommittedPrimitiveIndex();"
112 )?;
113 writeln!(
114 self.out,
115 " {extra_level}if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) {{"
116 )?;
117 writeln!(
118 self.out,
119 " {extra_level}ret.barycentrics = rq.CommittedTriangleBarycentrics();"
120 )?;
121 writeln!(
122 self.out,
123 " {extra_level}ret.front_face = rq.CommittedTriangleFrontFace();"
124 )?;
125 writeln!(self.out, " {extra_level}}}")?;
126 writeln!(
127 self.out,
128 " {extra_level}ret.object_to_world = rq.CommittedObjectToWorld4x3();"
129 )?;
130 writeln!(
131 self.out,
132 " {extra_level}ret.world_to_object = rq.CommittedWorldToObject4x3();"
133 )?;
134 writeln!(self.out, " {extra_level}}}")?;
135 if self.options.ray_query_initialization_tracking {
136 writeln!(self.out, " }}")?;
137 }
138 writeln!(self.out, " return ret;")?;
139 writeln!(self.out, "}}")?;
140 writeln!(self.out)?;
141 Ok(())
142 }
143 pub(super) fn write_candidate_intersection_function(
144 &mut self,
145 module: &crate::Module,
146 ) -> BackendResult {
147 self.write_type(module, module.special_types.ray_intersection.unwrap())?;
148 write!(self.out, " GetCandidateIntersection(")?;
149 self.write_value_type(
150 module,
151 &TypeInner::RayQuery {
152 vertex_return: false,
153 },
154 )?;
155 write!(self.out, " rq, ")?;
156 self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?;
157 writeln!(self.out, " rq_tracker) {{")?;
158 write!(self.out, " ")?;
159 self.write_type(module, module.special_types.ray_intersection.unwrap())?;
160 write!(self.out, " ret = (")?;
161 self.write_type(module, module.special_types.ray_intersection.unwrap())?;
162 writeln!(self.out, ")0;")?;
163 let mut extra_level = Level(0);
164 if self.options.ray_query_initialization_tracking {
165 write!(self.out, " if (")?;
166 self.write_contains_flags("rq_tracker", crate::back::RayQueryPoint::PROCEED.bits())?;
167 write!(self.out, " && !")?;
168 self.write_contains_flags(
169 "rq_tracker",
170 crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
171 )?;
172 writeln!(self.out, ") {{")?;
173 extra_level = extra_level.next();
174 }
175 writeln!(
176 self.out,
177 " {extra_level}CANDIDATE_TYPE kind = rq.CandidateType();"
178 )?;
179 writeln!(
180 self.out,
181 " {extra_level}if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{"
182 )?;
183 writeln!(
184 self.out,
185 " {extra_level}ret.kind = {};",
186 RayQueryIntersection::Triangle as u32
187 )?;
188 writeln!(
189 self.out,
190 " {extra_level}ret.t = rq.CandidateTriangleRayT();"
191 )?;
192 writeln!(
193 self.out,
194 " {extra_level}ret.barycentrics = rq.CandidateTriangleBarycentrics();"
195 )?;
196 writeln!(
197 self.out,
198 " {extra_level}ret.front_face = rq.CandidateTriangleFrontFace();"
199 )?;
200 writeln!(self.out, " {extra_level}}} else {{")?;
201 writeln!(
202 self.out,
203 " {extra_level}ret.kind = {};",
204 RayQueryIntersection::Aabb as u32
205 )?;
206 writeln!(self.out, " {extra_level}}}")?;
207
208 writeln!(
209 self.out,
210 " {extra_level}ret.instance_custom_data = rq.CandidateInstanceID();"
211 )?;
212 writeln!(
213 self.out,
214 " {extra_level}ret.instance_index = rq.CandidateInstanceIndex();"
215 )?;
216 writeln!(
217 self.out,
218 " {extra_level}ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();"
219 )?;
220 writeln!(
221 self.out,
222 " {extra_level}ret.geometry_index = rq.CandidateGeometryIndex();"
223 )?;
224 writeln!(
225 self.out,
226 " {extra_level}ret.primitive_index = rq.CandidatePrimitiveIndex();"
227 )?;
228 writeln!(
229 self.out,
230 " {extra_level}ret.object_to_world = rq.CandidateObjectToWorld4x3();"
231 )?;
232 writeln!(
233 self.out,
234 " {extra_level}ret.world_to_object = rq.CandidateWorldToObject4x3();"
235 )?;
236 if self.options.ray_query_initialization_tracking {
237 writeln!(self.out, " }}")?;
238 }
239 writeln!(self.out, " return ret;")?;
240 writeln!(self.out, "}}")?;
241 writeln!(self.out)?;
242 Ok(())
243 }
244
245 #[expect(clippy::too_many_arguments)]
246 pub(super) fn write_initialize_function(
247 &mut self,
248 module: &crate::Module,
249 mut level: Level,
250 query: Handle<crate::Expression>,
251 acceleration_structure: Handle<crate::Expression>,
252 descriptor: Handle<crate::Expression>,
253 rq_tracker: &str,
254 func_ctx: &crate::back::FunctionCtx<'_>,
255 ) -> BackendResult {
256 let base_level = level;
257
258 writeln!(self.out, "{level}{{")?;
260 level = level.next();
261 write!(self.out, "{level}")?;
262 self.write_type(
263 module,
264 module
265 .special_types
266 .ray_desc
267 .expect("should have been generated"),
268 )?;
269 write!(self.out, " naga_desc = ")?;
270 self.write_expr(module, descriptor, func_ctx)?;
271 writeln!(self.out, ";")?;
272
273 if self.options.ray_query_initialization_tracking {
274 writeln!(self.out, "{level}float naga_tmin = naga_desc.tmin;")?;
278 writeln!(self.out, "{level}float naga_tmax = naga_desc.tmax;")?;
279 writeln!(self.out, "{level}float3 naga_origin = naga_desc.origin;")?;
280 writeln!(self.out, "{level}float3 naga_dir = naga_desc.dir;")?;
281 writeln!(self.out, "{level}uint naga_flags = naga_desc.flags;")?;
282 write!(
283 self.out,
284 "{level}bool naga_tmin_valid = (naga_tmin >= 0.0) && (naga_tmin <= naga_tmax) && !"
285 )?;
286 self.write_nan("naga_tmin")?;
287 writeln!(self.out, ";")?;
288 write!(self.out, "{level}bool naga_tmax_valid = !")?;
289 self.write_nan("naga_tmax")?;
290 writeln!(self.out, ";")?;
291 write!(self.out, "{level}bool naga_origin_valid = !any(")?;
293 self.write_nan("naga_origin")?;
294 writeln!(self.out, ");")?;
295 write!(self.out, "{level}bool naga_dir_valid = !any(")?;
296 self.write_nan("naga_dir")?;
297 writeln!(self.out, ");")?;
298 write!(self.out, "{level}bool naga_contains_opaque = ")?;
299 self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_OPAQUE.bits())?;
300 writeln!(self.out, ";")?;
301 write!(self.out, "{level}bool naga_contains_no_opaque = ")?;
302 self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_NO_OPAQUE.bits())?;
303 writeln!(self.out, ";")?;
304 write!(self.out, "{level}bool naga_contains_cull_opaque = ")?;
305 self.write_contains_flags("naga_flags", crate::RayFlag::CULL_OPAQUE.bits())?;
306 writeln!(self.out, ";")?;
307 write!(self.out, "{level}bool naga_contains_cull_no_opaque = ")?;
308 self.write_contains_flags("naga_flags", crate::RayFlag::CULL_NO_OPAQUE.bits())?;
309 writeln!(self.out, ";")?;
310 write!(self.out, "{level}bool naga_contains_cull_front = ")?;
311 self.write_contains_flags("naga_flags", crate::RayFlag::CULL_FRONT_FACING.bits())?;
312 writeln!(self.out, ";")?;
313 write!(self.out, "{level}bool naga_contains_cull_back = ")?;
314 self.write_contains_flags("naga_flags", crate::RayFlag::CULL_BACK_FACING.bits())?;
315 writeln!(self.out, ";")?;
316 write!(self.out, "{level}bool naga_contains_skip_triangles = ")?;
317 self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_TRIANGLES.bits())?;
318 writeln!(self.out, ";")?;
319 write!(self.out, "{level}bool naga_contains_skip_aabbs = ")?;
320 self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_AABBS.bits())?;
321 writeln!(self.out, ";")?;
322 fn less_than_two_true(mut bools: Vec<&str>) -> Result<String, super::Error> {
324 assert!(bools.len() > 1, "Must have multiple booleans!");
325 let mut final_expr = String::new();
326 while let Some(last_bool) = bools.pop() {
327 for &bool in &bools {
328 if !final_expr.is_empty() {
329 final_expr.push_str("||");
330 }
331 write!(final_expr, " ({last_bool} && {bool}) ")?;
332 }
333 }
334 Ok(final_expr)
335 }
336 writeln!(
337 self.out,
338 "{level}bool naga_contains_skip_triangles_aabbs = {};",
339 less_than_two_true(vec![
340 "naga_contains_skip_triangles",
341 "naga_contains_skip_aabbs"
342 ])?
343 )?;
344 writeln!(
345 self.out,
346 "{level}bool naga_contains_skip_triangles_cull = {};",
347 less_than_two_true(vec![
348 "naga_contains_skip_triangles",
349 "naga_contains_cull_back",
350 "naga_contains_cull_front"
351 ])?
352 )?;
353 writeln!(
354 self.out,
355 "{level}bool naga_contains_multiple_opaque = {};",
356 less_than_two_true(vec![
357 "naga_contains_opaque",
358 "naga_contains_no_opaque",
359 "naga_contains_cull_opaque",
360 "naga_contains_cull_no_opaque"
361 ])?
362 )?;
363 writeln!(
364 self.out,
365 "{level}if (naga_tmin_valid && naga_tmax_valid && naga_origin_valid && naga_dir_valid && !(naga_contains_skip_triangles_aabbs || naga_contains_skip_triangles_cull || naga_contains_multiple_opaque)) {{"
366 )?;
367 level = level.next();
368 writeln!(
369 self.out,
370 "{level}{rq_tracker} = {rq_tracker} | {};",
371 crate::back::RayQueryPoint::INITIALIZED.bits()
372 )?;
373 }
374 write!(self.out, "{level}")?;
375 self.write_expr(module, query, func_ctx)?;
376 write!(self.out, ".TraceRayInline(")?;
377 self.write_expr(module, acceleration_structure, func_ctx)?;
378 writeln!(
379 self.out,
380 ", naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc));"
381 )?;
382 if self.options.ray_query_initialization_tracking {
383 writeln!(self.out, "{base_level} }}")?;
384 }
385 writeln!(self.out, "{base_level}}}")?;
386 Ok(())
387 }
388
389 pub(super) fn write_proceed(
390 &mut self,
391 module: &crate::Module,
392 mut level: Level,
393 query: Handle<crate::Expression>,
394 result: Handle<crate::Expression>,
395 rq_tracker: &str,
396 func_ctx: &crate::back::FunctionCtx<'_>,
397 ) -> BackendResult {
398 let base_level = level;
399 write!(self.out, "{level}")?;
400 let name = Baked(result).to_string();
401 writeln!(self.out, "bool {name} = false;")?;
402 if self.options.ray_query_initialization_tracking {
404 writeln!(self.out, "{level}{{")?;
405 level = level.next();
406 write!(self.out, "{level}bool naga_has_initialized = ")?;
407 self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?;
408 writeln!(self.out, ";")?;
409 write!(self.out, "{level}bool naga_has_finished = ")?;
410 self.write_contains_flags(
411 rq_tracker,
412 crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
413 )?;
414 writeln!(self.out, ";")?;
415 writeln!(
416 self.out,
417 "{level}if (naga_has_initialized && !naga_has_finished) {{"
418 )?;
419 level = level.next();
420 }
421
422 write!(self.out, "{level}{name} = ")?;
423 self.write_expr(module, query, func_ctx)?;
424 writeln!(self.out, ".Proceed();")?;
425
426 if self.options.ray_query_initialization_tracking {
427 writeln!(
428 self.out,
429 "{level}{rq_tracker} = {rq_tracker} | {};",
430 crate::back::RayQueryPoint::PROCEED.bits()
431 )?;
432 writeln!(
433 self.out,
434 "{level}if (!{name}) {{ {rq_tracker} = {rq_tracker} | {}; }}",
435 crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits()
436 )?;
437 writeln!(self.out, "{base_level}}}}}")?;
438 }
439
440 self.named_expressions.insert(result, name);
441
442 Ok(())
443 }
444
445 pub(super) fn write_generate_intersection(
446 &mut self,
447 module: &crate::Module,
448 mut level: Level,
449 query: Handle<crate::Expression>,
450 hit_t: Handle<crate::Expression>,
451 rq_tracker: &str,
452 func_ctx: &crate::back::FunctionCtx<'_>,
453 ) -> BackendResult {
454 let base_level = level;
455 if self.options.ray_query_initialization_tracking {
456 write!(self.out, "{level}if (")?;
457 self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?;
458 write!(self.out, " && !")?;
459 self.write_contains_flags(
460 rq_tracker,
461 crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
462 )?;
463 writeln!(self.out, ") {{")?;
464 level = level.next();
465 write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?;
466 self.write_expr(module, query, func_ctx)?;
467 writeln!(self.out, ".CandidateType();")?;
468 write!(self.out, "{level}float naga_tmin = ")?;
469 self.write_expr(module, query, func_ctx)?;
470 writeln!(self.out, ".RayTMin();")?;
471 write!(self.out, "{level}float naga_tcurrentmax = ")?;
472 self.write_expr(module, query, func_ctx)?;
473 writeln!(self.out, ".CommittedRayT();")?;
477 write!(
478 self.out,
479 "{level}if ((naga_kind == CANDIDATE_PROCEDURAL_PRIMITIVE) && (naga_tmin <="
480 )?;
481 self.write_expr(module, hit_t, func_ctx)?;
482 write!(self.out, ") && (")?;
483 self.write_expr(module, hit_t, func_ctx)?;
484 writeln!(self.out, " <= naga_tcurrentmax)) {{")?;
485 level = level.next();
486 }
487
488 write!(self.out, "{level}")?;
489 self.write_expr(module, query, func_ctx)?;
490 write!(self.out, ".CommitProceduralPrimitiveHit(")?;
491 self.write_expr(module, hit_t, func_ctx)?;
492 writeln!(self.out, ");")?;
493 if self.options.ray_query_initialization_tracking {
494 writeln!(self.out, "{base_level}}}}}")?;
495 }
496 Ok(())
497 }
498 pub(super) fn write_confirm_intersection(
499 &mut self,
500 module: &crate::Module,
501 mut level: Level,
502 query: Handle<crate::Expression>,
503 rq_tracker: &str,
504 func_ctx: &crate::back::FunctionCtx<'_>,
505 ) -> BackendResult {
506 let base_level = level;
507 if self.options.ray_query_initialization_tracking {
508 write!(self.out, "{level}if (")?;
509 self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?;
510 write!(self.out, " && !")?;
511 self.write_contains_flags(
512 rq_tracker,
513 crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
514 )?;
515 writeln!(self.out, ") {{")?;
516 level = level.next();
517 write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?;
518 self.write_expr(module, query, func_ctx)?;
519 writeln!(self.out, ".CandidateType();")?;
520 writeln!(
521 self.out,
522 "{level}if (naga_kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{"
523 )?;
524 level = level.next();
525 }
526
527 write!(self.out, "{level}")?;
528 self.write_expr(module, query, func_ctx)?;
529 writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
530 if self.options.ray_query_initialization_tracking {
531 writeln!(self.out, "{base_level}}}}}")?;
532 }
533 Ok(())
534 }
535
536 pub(super) fn write_terminate(
537 &mut self,
538 module: &crate::Module,
539 mut level: Level,
540 query: Handle<crate::Expression>,
541 rq_tracker: &str,
542 func_ctx: &crate::back::FunctionCtx<'_>,
543 ) -> BackendResult {
544 let base_level = level;
545 if self.options.ray_query_initialization_tracking {
546 write!(self.out, "{level}if (")?;
547 self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?;
550 writeln!(self.out, ") {{")?;
551 level = level.next();
552 }
553
554 write!(self.out, "{level}")?;
555 self.write_expr(module, query, func_ctx)?;
556 writeln!(self.out, ".Abort();")?;
557
558 if self.options.ray_query_initialization_tracking {
559 writeln!(self.out, "{base_level}}}")?;
560 }
561
562 Ok(())
563 }
564}