1use alloc::{
2 format,
3 string::{String, ToString},
4};
5use core::fmt::Write;
6
7use crate::{
8 back::{
9 self,
10 msl::{
11 writer::{NameKeyExt, StatementContext, TypeContext, WrappedFunction},
12 BackendResult, Error, Writer, NAMESPACE,
13 },
14 Baked, INDENT,
15 },
16 Handle,
17};
18
/// Namespace of Metal's ray tracing API (`intersection_query`,
/// `intersection_params`, `ray`, `intersection_type`, ...) as spelled in the
/// generated MSL.
pub(super) const RT_NAMESPACE: &str = "metal::raytracing";
20
21pub(super) fn metal_intersector_ty() -> String {
23 format!("{RT_NAMESPACE}::intersection_query<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data>")
24}
25
/// Base name of the generated intersection getter helpers. The emitted helper
/// name appends `_` plus the `committed` flag, i.e. `_true` for the committed
/// intersection and `_false` for the candidate one.
pub(super) const INTERSECTION_FUNCTION_NAME: &str = "ray_query_get_intersection";
/// Prefix joined with a ray query local's name to form the variable that holds
/// its `RayQueryPoint` state bits when initialization tracking is enabled.
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix joined with a ray query local's name to form the variable that
/// records `desc.tmax` at initialization; generated intersections are checked
/// against it when initialization tracking is enabled.
pub(crate) const RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX: &str = "naga_query_tmax_tracker_for_";
29
30impl<W: Write> Writer<W> {
31 fn write_not_finite(&mut self, expr: &str) -> BackendResult {
32 self.write_contains_flags(&format!("as_type<uint>({expr})"), 0x7f800000)
33 }
34
35 fn write_is_nan(&mut self, expr: &str) -> BackendResult {
39 write!(self.out, "(")?;
40 self.write_not_finite(expr)?;
41 write!(self.out, " && ((as_type<uint>({expr}) & 0x7fffff) != 0))")?;
42 Ok(())
43 }
44
45 fn write_contains_flags(&mut self, expr: &str, flags: u32) -> BackendResult {
46 write!(self.out, "(({expr} & {flags}) == {flags})")?;
47 Ok(())
48 }
49
50 pub(super) fn write_rq_get_intersection_function(
56 &mut self,
57 module: &crate::Module,
58 committed: bool,
59 options: &super::Options,
60 ) -> BackendResult {
61 let wrapped = WrappedFunction::RayQueryGetIntersection { committed };
62 if !self.wrapped_functions.insert(wrapped) {
63 return Ok(());
64 }
65
66 let ty = if committed { "committed" } else { "candidate" };
67 let intersection = TypeContext {
68 handle: module
69 .special_types
70 .ray_intersection
71 .expect("intersection ty should be there for intersection function"),
72 gctx: module.to_ctx(),
73 names: &self.names,
74 access: crate::StorageAccess::empty(),
75 first_time: false,
76 };
77 let mut base_level = back::Level(1);
78 writeln!(
79 self.out,
80 "{intersection} {INTERSECTION_FUNCTION_NAME}_{committed}({} intersector",
81 metal_intersector_ty()
82 )?;
83 if options.ray_query_initialization_tracking {
84 writeln!(self.out, ", uint intersector_tracker")?;
85 }
86 writeln!(self.out, ") {{")?;
87 writeln!(
89 self.out,
90 "{base_level}{intersection} intersection = {intersection} {{}};"
91 )?;
92
93 if options.ray_query_initialization_tracking {
94 write!(self.out, "{base_level}if (")?;
95 if committed {
96 self.write_contains_flags(
97 "intersector_tracker",
98 back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
99 )?;
100 } else {
101 self.write_contains_flags(
102 "intersector_tracker",
103 back::RayQueryPoint::PROCEED.bits(),
104 )?;
105 write!(self.out, " && !")?;
106 self.write_contains_flags(
107 "intersector_tracker",
108 back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
109 )?;
110 }
111 writeln!(self.out, ") {{")?;
112 base_level = base_level.next();
113 }
114
115 writeln!(self.out, "{base_level}{RT_NAMESPACE}::intersection_type ty = intersector.get_{ty}_intersection_type();")?;
116 writeln!(
118 self.out,
119 "{base_level}if (ty == {RT_NAMESPACE}::intersection_type::triangle) {{"
120 )?;
121 writeln!(
122 self.out,
123 "{base_level}{INDENT}intersection.kind = {};",
124 crate::RayQueryIntersection::Triangle as u32
125 )?;
126 if !committed {
127 writeln!(
128 self.out,
129 "{base_level}{INDENT}intersection.t = intersector.get_candidate_triangle_distance();"
130 )?;
131 }
132 writeln!(self.out, "{base_level}{INDENT}intersection.barycentrics = intersector.get_{ty}_triangle_barycentric_coord();")?;
133 writeln!(
134 self.out,
135 "{base_level}{INDENT}intersection.front_face = intersector.is_{ty}_triangle_front_facing();"
136 )?;
137 writeln!(
140 self.out,
141 "{base_level}}} else if (ty == {RT_NAMESPACE}::intersection_type::bounding_box) {{"
142 )?;
143 if committed {
144 writeln!(
145 self.out,
146 "{base_level}{INDENT}intersection.kind = {};",
147 crate::RayQueryIntersection::Generated as u32
148 )?;
149 } else {
150 writeln!(
151 self.out,
152 "{base_level}{INDENT}intersection.kind = {};",
153 crate::RayQueryIntersection::Aabb as u32
154 )?;
155 }
156 writeln!(self.out, "{base_level}}}")?;
157
158 writeln!(
160 self.out,
161 "{base_level}if (ty != {RT_NAMESPACE}::intersection_type::none) {{"
162 )?;
163 if committed {
164 writeln!(
165 self.out,
166 "{base_level}{INDENT}intersection.t = intersector.get_committed_distance();"
167 )?;
168 }
169 writeln!(self.out, "{base_level}{INDENT}intersection.instance_custom_data = intersector.get_{ty}_user_instance_id();")?;
170 writeln!(
171 self.out,
172 "{base_level}{INDENT}intersection.instance_index = intersector.get_{ty}_instance_id();"
173 )?;
174 writeln!(
177 self.out,
178 "{base_level}{INDENT}intersection.geometry_index = intersector.get_{ty}_geometry_id();"
179 )?;
180 writeln!(
181 self.out,
182 "{base_level}{INDENT}intersection.primitive_index = intersector.get_{ty}_primitive_id();"
183 )?;
184 writeln!(self.out, "{base_level}{INDENT}intersection.object_to_world = intersector.get_{ty}_object_to_world_transform();")?;
185 writeln!(self.out, "{base_level}{INDENT}intersection.world_to_object = intersector.get_{ty}_world_to_object_transform();")?;
186 writeln!(self.out, "{base_level}}}")?;
187
188 if options.ray_query_initialization_tracking {
189 writeln!(self.out, "{INDENT}}}")?;
190 }
191
192 writeln!(self.out, "{INDENT}return intersection;")?;
193 writeln!(self.out, "}}")?;
194
195 Ok(())
196 }
197
198 pub(super) fn write_ray_query_stmt(
199 &mut self,
200 level: back::Level,
201 context: &StatementContext,
202 query: Handle<crate::Expression>,
203 fun: &crate::RayQueryFunction,
204 ) -> BackendResult {
205 if context.expression.lang_version < (2, 4) {
206 return Err(Error::UnsupportedRayTracing);
207 }
208
209 let crate::Expression::LocalVariable(query_var) =
221 context.expression.function.expressions[query]
222 else {
223 unreachable!()
224 };
225
226 let tracker_expr_name = format!(
227 "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
228 self.names[&crate::proc::NameKey::local(context.expression.origin, query_var)]
229 );
230
231 let tmax_tracker_expr_name = format!(
232 "{RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX}{}",
233 self.names[&crate::proc::NameKey::local(context.expression.origin, query_var)]
234 );
235
236 match *fun {
238 crate::RayQueryFunction::Initialize {
239 acceleration_structure,
240 descriptor,
241 } => {
242 writeln!(self.out, "{level}{{")?;
247
248 let inner_level = level.next();
249
250 let naga_ray_desc_ty = TypeContext {
251 handle: context
252 .expression
253 .module
254 .special_types
255 .ray_desc
256 .expect("ray desc is required as an argument so should be there"),
257 gctx: context.expression.module.to_ctx(),
258 names: &self.names,
259 access: crate::StorageAccess::empty(),
260 first_time: false,
261 };
262
263 write!(self.out, "{inner_level}{naga_ray_desc_ty} desc = ")?;
264 self.put_expression(descriptor, &context.expression, false)?;
265 writeln!(self.out, ";")?;
266
267 writeln!(
269 self.out,
270 "{inner_level}{RT_NAMESPACE}::intersection_params params;"
271 )?;
272
273 {
274 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
276 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
277 writeln!(self.out, "{inner_level}{RT_NAMESPACE}::opacity_cull_mode cull_mode =
278{inner_level}{INDENT}(desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (
279{inner_level}{INDENT}{INDENT}(desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : {RT_NAMESPACE}::opacity_cull_mode::none
280{inner_level}{INDENT});")?;
281 writeln!(
282 self.out,
283 "{inner_level}params.set_opacity_cull_mode(cull_mode);"
284 )?;
285
286 if context.expression.ray_query_initialization_tracking {
287 writeln!(self.out, "{inner_level}bool force_opacity = cull_mode == {RT_NAMESPACE}::opacity_cull_mode::none;")?;
288 }
289 }
290 {
291 let mut current_level = inner_level;
292 if context.expression.ray_query_initialization_tracking {
293 writeln!(self.out, "{inner_level}if (force_opacity) {{")?;
294 current_level = current_level.next();
295 }
296 let f_opaque = back::RayFlag::OPAQUE.bits();
298 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
299 writeln!(self.out, "{current_level}params.force_opacity(
300{current_level} (desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (
301{current_level} (desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : {RT_NAMESPACE}::forced_opacity::none
302{current_level} )
303{current_level});")?;
304
305 if context.expression.ray_query_initialization_tracking {
306 writeln!(self.out, "{inner_level}}}")?;
307 }
308 }
309 {
310 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
311 writeln!(
312 self.out,
313 "{inner_level}params.accept_any_intersection((desc.flags & {flag}) != 0);"
314 )?;
315 }
316
317 writeln!(
318 self.out,
319 "{inner_level}{RT_NAMESPACE}::ray ray = {RT_NAMESPACE}::ray(desc.origin, desc.dir, desc.tmin, desc.tmax);"
320 )?;
321
322 let mut init_level = inner_level;
323
324 if context.expression.ray_query_initialization_tracking {
327 write!(self.out, "{inner_level}bool invalid_nan_infs = ")?;
328 for (idx, &field_access) in [
330 "origin.x", "origin.y", "origin.z", "dir.x", "dir.y", "dir.z", "tmin",
331 ]
332 .iter()
333 .enumerate()
334 {
335 if idx != 0 {
336 write!(self.out, " || ")?;
337 }
338
339 self.write_not_finite(&format!("desc.{field_access}"))?;
340 }
341
342 write!(self.out, " || ")?;
343 self.write_is_nan("desc.tmax")?;
344 writeln!(self.out, ";")?;
345
346 writeln!(self.out, "{inner_level}bool invalid_t = (desc.tmin > desc.tmax) || (desc.tmin < 0.0);")?;
348 writeln!(self.out, "{inner_level}bool invalid_dir = {NAMESPACE}::all({NAMESPACE}::abs(desc.dir) == 0.0);")?;
353
354 writeln!(
355 self.out,
356 "{inner_level}if (!(invalid_dir || invalid_t || invalid_nan_infs)) {{"
357 )?;
358 init_level = init_level.next();
359 }
360
361 write!(self.out, "{init_level}")?;
362 self.put_expression(query, &context.expression, true)?;
366 write!(self.out, ".reset(ray,")?;
367 self.put_expression(acceleration_structure, &context.expression, true)?;
368 writeln!(self.out, ", desc.cull_mask, params);")?;
369 if context.expression.ray_query_initialization_tracking {
370 writeln!(
374 self.out,
375 "{init_level}{tracker_expr_name} = {};",
376 back::RayQueryPoint::INITIALIZED.bits()
377 )?;
378 writeln!(
379 self.out,
380 "{init_level}{tmax_tracker_expr_name} = desc.tmax;"
381 )?;
382 writeln!(self.out, "{inner_level}}}")?;
383 }
384 writeln!(self.out, "{level}}}")?;
385 }
386 crate::RayQueryFunction::Proceed { result } => {
387 let mut current_level = level;
388 write!(self.out, "{current_level}")?;
389 let name = Baked(result).to_string();
390 self.start_baking_expression(result, &context.expression, &name)?;
391 self.named_expressions.insert(result, name.clone());
392
393 writeln!(self.out, "false;")?;
394
395 if context.expression.ray_query_initialization_tracking {
396 write!(self.out, "{level}if (")?;
397 self.write_contains_flags(
398 &tracker_expr_name,
399 back::RayQueryPoint::INITIALIZED.bits(),
400 )?;
401 write!(self.out, " && !")?;
402 self.write_contains_flags(
403 &tracker_expr_name,
404 back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
405 )?;
406 write!(self.out, ")")?;
407 writeln!(self.out, " {{")?;
408 current_level = current_level.next();
409 }
410 write!(self.out, "{current_level}{name} = ")?;
411 self.put_expression(query, &context.expression, true)?;
412 writeln!(self.out, ".next();")?;
413 if context.expression.ray_query_initialization_tracking {
414 writeln!(self.out, "{current_level}{tracker_expr_name} = {tracker_expr_name} | ({name} ? {}: {});", back::RayQueryPoint::PROCEED.bits(), (back::RayQueryPoint::PROCEED | back::RayQueryPoint::FINISHED_TRAVERSAL).bits())?;
415 writeln!(self.out, "{level}}}")?;
416 }
417 }
418 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
419 let mut current_level = level;
420 if context.expression.ray_query_initialization_tracking {
421 write!(self.out, "{level}if (")?;
422 self.write_contains_flags(
423 &tracker_expr_name,
424 back::RayQueryPoint::PROCEED.bits(),
425 )?;
426 write!(self.out, " && !")?;
427 self.write_contains_flags(
428 &tracker_expr_name,
429 back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
430 )?;
431 write!(self.out, ")")?;
432 } else {
433 write!(self.out, "{level}")?;
435 }
436 writeln!(self.out, "{{")?;
437 current_level = current_level.next();
438 write!(self.out, "{current_level}float t = ")?;
439 self.put_expression(hit_t, &context.expression, true)?;
440 writeln!(self.out, ";")?;
441 if context.expression.ray_query_initialization_tracking {
442 write!(
443 self.out,
444 "{current_level}float current_max_t = {tmax_tracker_expr_name};
445{current_level}if ("
446 )?;
447 self.put_expression(query, &context.expression, true)?;
448 write!(self.out, ".get_committed_intersection_type() != {RT_NAMESPACE}::intersection_type::none) {{
449{current_level}{INDENT}current_max_t = ")?;
450 self.put_expression(query, &context.expression, true)?;
451 write!(
452 self.out,
453 ".get_committed_distance();
454{current_level}}}
455{current_level}if ("
456 )?;
457 self.put_expression(query, &context.expression, true)?;
458 write!(self.out, ".get_candidate_intersection_type() == {RT_NAMESPACE}::intersection_type::bounding_box && (")?;
459 self.put_expression(query, &context.expression, true)?;
460 write!(self.out, ".get_ray_min_distance()")?;
461 writeln!(self.out, " <= t) && (t <= current_max_t)) {{")?;
462 current_level = current_level.next();
463 }
464 write!(self.out, "{current_level}")?;
465 self.put_expression(query, &context.expression, true)?;
466 writeln!(self.out, ".commit_bounding_box_intersection(t);")?;
467 if context.expression.ray_query_initialization_tracking {
468 writeln!(self.out, "{level}{INDENT}}}")?;
469 }
470 writeln!(self.out, "{level}}}")?;
471 }
472 crate::RayQueryFunction::ConfirmIntersection => {
473 let mut current_level = level;
474 if context.expression.ray_query_initialization_tracking {
475 write!(self.out, "{level}if (")?;
476 self.write_contains_flags(
477 &tracker_expr_name,
478 back::RayQueryPoint::PROCEED.bits(),
479 )?;
480 write!(self.out, " && !")?;
481 self.write_contains_flags(
482 &tracker_expr_name,
483 back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
484 )?;
485 writeln!(self.out, ") {{")?;
486 current_level = current_level.next();
487 write!(self.out, "{current_level}if (")?;
488 self.put_expression(query, &context.expression, true)?;
489 writeln!(self.out, ".get_candidate_intersection_type() == {RT_NAMESPACE}::intersection_type::triangle) {{")?;
490 }
491 write!(self.out, "{level}")?;
492 self.put_expression(query, &context.expression, true)?;
493 writeln!(self.out, ".commit_triangle_intersection();")?;
494 if context.expression.ray_query_initialization_tracking {
495 writeln!(
496 self.out,
497 "{level}{INDENT}}}
498{level}}}"
499 )?;
500 }
501 }
502 crate::RayQueryFunction::Terminate => {
503 let mut current_level = level;
504 if context.expression.ray_query_initialization_tracking {
505 write!(self.out, "{level}if (")?;
506 self.write_contains_flags(
507 &tracker_expr_name,
508 back::RayQueryPoint::PROCEED.bits(),
509 )?;
510 write!(self.out, " && !")?;
511 self.write_contains_flags(
512 &tracker_expr_name,
513 back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
514 )?;
515 writeln!(self.out, ") {{")?;
516 current_level = current_level.next();
517 }
518 write!(self.out, "{current_level}")?;
519 self.put_expression(query, &context.expression, true)?;
520 writeln!(self.out, ".abort();")?;
523 if context.expression.ray_query_initialization_tracking {
526 writeln!(self.out, "{level}}}")?;
527 }
528 }
529 }
530
531 Ok(())
532 }
533}