// naga/back/msl/ray.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4};
5use core::fmt::Write;
6
7use crate::{
8    back::{
9        self,
10        msl::{
11            writer::{NameKeyExt, StatementContext, TypeContext, WrappedFunction},
12            BackendResult, Error, Writer, NAMESPACE,
13        },
14        Baked, INDENT,
15    },
16    Handle,
17};
18
/// Namespace of the Metal ray tracing types used by the generated code
/// (`intersection_query`, `intersection_params`, `ray`, `intersection_type`, ...).
pub(super) const RT_NAMESPACE: &str = "metal::raytracing";
20
21/// The ray query type, needs to be a function so it can format the constants.
22pub(super) fn metal_intersector_ty() -> String {
23    format!("{RT_NAMESPACE}::intersection_query<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data>")
24}
25
/// Base name of the generated helper that reads a ray query's intersection.
/// The emitted function name is this followed by `_true` (committed
/// intersection) or `_false` (candidate intersection).
pub(super) const INTERSECTION_FUNCTION_NAME: &str = "ray_query_get_intersection";
/// Prefix of the per-query tracker variable whose bits record which ray query
/// operations (`RayQueryPoint` flags) have happened; the query's name is appended.
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix of the per-query variable holding the `tmax` the query was
/// initialized with; the query's name is appended.
pub(crate) const RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX: &str = "naga_query_tmax_tracker_for_";
29
30impl<W: Write> Writer<W> {
31    fn write_not_finite(&mut self, expr: &str) -> BackendResult {
32        self.write_contains_flags(&format!("as_type<uint>({expr})"), 0x7f800000)
33    }
34
35    /// Checks whether `expr` does not have the bitpattern of IEEE f32 `NaN`.
36    ///
37    /// Note that this evaluates `expr` in the written code multiple times.
38    fn write_is_nan(&mut self, expr: &str) -> BackendResult {
39        write!(self.out, "(")?;
40        self.write_not_finite(expr)?;
41        write!(self.out, " && ((as_type<uint>({expr}) & 0x7fffff) != 0))")?;
42        Ok(())
43    }
44
45    fn write_contains_flags(&mut self, expr: &str, flags: u32) -> BackendResult {
46        write!(self.out, "(({expr} & {flags}) == {flags})")?;
47        Ok(())
48    }
49
50    /// Writes a function to get the current intersection from the ray query
51    ///
52    /// Like other backends, this is needed to have a single branch for constructing
53    /// the parts of the intersection that need to be checked whether they do or don't
54    /// hit.
55    pub(super) fn write_rq_get_intersection_function(
56        &mut self,
57        module: &crate::Module,
58        committed: bool,
59        options: &super::Options,
60    ) -> BackendResult {
61        let wrapped = WrappedFunction::RayQueryGetIntersection { committed };
62        if !self.wrapped_functions.insert(wrapped) {
63            return Ok(());
64        }
65
66        let ty = if committed { "committed" } else { "candidate" };
67        let intersection = TypeContext {
68            handle: module
69                .special_types
70                .ray_intersection
71                .expect("intersection ty should be there for intersection function"),
72            gctx: module.to_ctx(),
73            names: &self.names,
74            access: crate::StorageAccess::empty(),
75            first_time: false,
76        };
77        let mut base_level = back::Level(1);
78        writeln!(
79            self.out,
80            "{intersection} {INTERSECTION_FUNCTION_NAME}_{committed}({} intersector",
81            metal_intersector_ty()
82        )?;
83        if options.ray_query_initialization_tracking {
84            writeln!(self.out, ", uint intersector_tracker")?;
85        }
86        writeln!(self.out, ") {{")?;
87        // Initialize the intersection to its default values (which should be zero).
88        writeln!(
89            self.out,
90            "{base_level}{intersection} intersection = {intersection} {{}};"
91        )?;
92
93        if options.ray_query_initialization_tracking {
94            write!(self.out, "{base_level}if (")?;
95            if committed {
96                self.write_contains_flags(
97                    "intersector_tracker",
98                    back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
99                )?;
100            } else {
101                self.write_contains_flags(
102                    "intersector_tracker",
103                    back::RayQueryPoint::PROCEED.bits(),
104                )?;
105                write!(self.out, " && !")?;
106                self.write_contains_flags(
107                    "intersector_tracker",
108                    back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
109                )?;
110            }
111            writeln!(self.out, ") {{")?;
112            base_level = base_level.next();
113        }
114
115        writeln!(self.out, "{base_level}{RT_NAMESPACE}::intersection_type ty = intersector.get_{ty}_intersection_type();")?;
116        // If the ray hit a triangle, call all methods that require that and set the intersection type.
117        writeln!(
118            self.out,
119            "{base_level}if (ty == {RT_NAMESPACE}::intersection_type::triangle) {{"
120        )?;
121        writeln!(
122            self.out,
123            "{base_level}{INDENT}intersection.kind = {};",
124            crate::RayQueryIntersection::Triangle as u32
125        )?;
126        if !committed {
127            writeln!(
128                self.out,
129                "{base_level}{INDENT}intersection.t = intersector.get_candidate_triangle_distance();"
130            )?;
131        }
132        writeln!(self.out, "{base_level}{INDENT}intersection.barycentrics = intersector.get_{ty}_triangle_barycentric_coord();")?;
133        writeln!(
134            self.out,
135            "{base_level}{INDENT}intersection.front_face = intersector.is_{ty}_triangle_front_facing();"
136        )?;
137        // Otherwise, if the ray hit an AABB (called a bounding box in metal) set the intersection type
138        // (which depends on whether this is a committed or candidate intersection).
139        writeln!(
140            self.out,
141            "{base_level}}} else if (ty == {RT_NAMESPACE}::intersection_type::bounding_box) {{"
142        )?;
143        if committed {
144            writeln!(
145                self.out,
146                "{base_level}{INDENT}intersection.kind = {};",
147                crate::RayQueryIntersection::Generated as u32
148            )?;
149        } else {
150            writeln!(
151                self.out,
152                "{base_level}{INDENT}intersection.kind = {};",
153                crate::RayQueryIntersection::Aabb as u32
154            )?;
155        }
156        writeln!(self.out, "{base_level}}}")?;
157
158        // If the ray hit anything at all, call all methods that require that.
159        writeln!(
160            self.out,
161            "{base_level}if (ty != {RT_NAMESPACE}::intersection_type::none) {{"
162        )?;
163        if committed {
164            writeln!(
165                self.out,
166                "{base_level}{INDENT}intersection.t = intersector.get_committed_distance();"
167            )?;
168        }
169        writeln!(self.out, "{base_level}{INDENT}intersection.instance_custom_data = intersector.get_{ty}_user_instance_id();")?;
170        writeln!(
171            self.out,
172            "{base_level}{INDENT}intersection.instance_index = intersector.get_{ty}_instance_id();"
173        )?;
174        // Metal does not appear to support obtaining the intersection offset from a ray query.
175        //writeln!(self.out, "{level}{level}intersection.sbt_record_offset = intersector.get_{ty}_user_instance_id();")?;
176        writeln!(
177            self.out,
178            "{base_level}{INDENT}intersection.geometry_index = intersector.get_{ty}_geometry_id();"
179        )?;
180        writeln!(
181            self.out,
182            "{base_level}{INDENT}intersection.primitive_index = intersector.get_{ty}_primitive_id();"
183        )?;
184        writeln!(self.out, "{base_level}{INDENT}intersection.object_to_world = intersector.get_{ty}_object_to_world_transform();")?;
185        writeln!(self.out, "{base_level}{INDENT}intersection.world_to_object = intersector.get_{ty}_world_to_object_transform();")?;
186        writeln!(self.out, "{base_level}}}")?;
187
188        if options.ray_query_initialization_tracking {
189            writeln!(self.out, "{INDENT}}}")?;
190        }
191
192        writeln!(self.out, "{INDENT}return intersection;")?;
193        writeln!(self.out, "}}")?;
194
195        Ok(())
196    }
197
198    pub(super) fn write_ray_query_stmt(
199        &mut self,
200        level: back::Level,
201        context: &StatementContext,
202        query: Handle<crate::Expression>,
203        fun: &crate::RayQueryFunction,
204    ) -> BackendResult {
205        if context.expression.lang_version < (2, 4) {
206            return Err(Error::UnsupportedRayTracing);
207        }
208
209        // There are three possibilities for a ptr to be:
210        // 1. A variable
211        // 2. A function argument
212        // 3. part of a struct
213        //
214        // 2 and 3 are not possible, a ray query (in naga IR)
215        // is not allowed to be passed into a function, and
216        // all languages disallow it in a struct (you get fun results if
217        // you try it :) ).
218        //
219        // Therefore, the ray query expression must be a variable.
220        let crate::Expression::LocalVariable(query_var) =
221            context.expression.function.expressions[query]
222        else {
223            unreachable!()
224        };
225
226        let tracker_expr_name = format!(
227            "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
228            self.names[&crate::proc::NameKey::local(context.expression.origin, query_var)]
229        );
230
231        let tmax_tracker_expr_name = format!(
232            "{RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX}{}",
233            self.names[&crate::proc::NameKey::local(context.expression.origin, query_var)]
234        );
235
236        // TODO: check for misuse.
237        match *fun {
238            crate::RayQueryFunction::Initialize {
239                acceleration_structure,
240                descriptor,
241            } => {
242                //TODO: how to deal with winding? Is it by default the same as the other APIs?
243
244                // Put everything in a block so that the variable names
245                // do not conflict with user variable names
246                writeln!(self.out, "{level}{{")?;
247
248                let inner_level = level.next();
249
250                let naga_ray_desc_ty = TypeContext {
251                    handle: context
252                        .expression
253                        .module
254                        .special_types
255                        .ray_desc
256                        .expect("ray desc is required as an argument so should be there"),
257                    gctx: context.expression.module.to_ctx(),
258                    names: &self.names,
259                    access: crate::StorageAccess::empty(),
260                    first_time: false,
261                };
262
263                write!(self.out, "{inner_level}{naga_ray_desc_ty} desc = ")?;
264                self.put_expression(descriptor, &context.expression, false)?;
265                writeln!(self.out, ";")?;
266
267                // Set up intersection parameters
268                writeln!(
269                    self.out,
270                    "{inner_level}{RT_NAMESPACE}::intersection_params params;"
271                )?;
272
273                {
274                    // Determine whether or not to cull opaque/non-opaques
275                    let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
276                    let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
277                    writeln!(self.out, "{inner_level}{RT_NAMESPACE}::opacity_cull_mode cull_mode = 
278{inner_level}{INDENT}(desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (
279{inner_level}{INDENT}{INDENT}(desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : {RT_NAMESPACE}::opacity_cull_mode::none
280{inner_level}{INDENT});")?;
281                    writeln!(
282                        self.out,
283                        "{inner_level}params.set_opacity_cull_mode(cull_mode);"
284                    )?;
285
286                    if context.expression.ray_query_initialization_tracking {
287                        writeln!(self.out, "{inner_level}bool force_opacity = cull_mode == {RT_NAMESPACE}::opacity_cull_mode::none;")?;
288                    }
289                }
290                {
291                    let mut current_level = inner_level;
292                    if context.expression.ray_query_initialization_tracking {
293                        writeln!(self.out, "{inner_level}if (force_opacity) {{")?;
294                        current_level = current_level.next();
295                    }
296                    // Determine whether to force a particular opacity
297                    let f_opaque = back::RayFlag::OPAQUE.bits();
298                    let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
299                    writeln!(self.out, "{current_level}params.force_opacity(
300{current_level}    (desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (
301{current_level}        (desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : {RT_NAMESPACE}::forced_opacity::none
302{current_level}    )
303{current_level});")?;
304
305                    if context.expression.ray_query_initialization_tracking {
306                        writeln!(self.out, "{inner_level}}}")?;
307                    }
308                }
309                {
310                    let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
311                    writeln!(
312                        self.out,
313                        "{inner_level}params.accept_any_intersection((desc.flags & {flag}) != 0);"
314                    )?;
315                }
316
317                writeln!(
318                    self.out,
319                    "{inner_level}{RT_NAMESPACE}::ray ray = {RT_NAMESPACE}::ray(desc.origin, desc.dir, desc.tmin, desc.tmax);"
320                )?;
321
322                let mut init_level = inner_level;
323
324                // The `reset` function is virtually undocumented (many of the Metal ray tracing functions lack it), so to be safe,
325                // this assumes an invalid ray is UB (NOTE: invalid ray behaviour is defined for intersectors).
326                if context.expression.ray_query_initialization_tracking {
327                    write!(self.out, "{inner_level}bool invalid_nan_infs = ")?;
328                    // tmax needs special handling because it can be INF
329                    for (idx, &field_access) in [
330                        "origin.x", "origin.y", "origin.z", "dir.x", "dir.y", "dir.z", "tmin",
331                    ]
332                    .iter()
333                    .enumerate()
334                    {
335                        if idx != 0 {
336                            write!(self.out, " || ")?;
337                        }
338
339                        self.write_not_finite(&format!("desc.{field_access}"))?;
340                    }
341
342                    write!(self.out, " || ")?;
343                    self.write_is_nan("desc.tmax")?;
344                    writeln!(self.out, ";")?;
345
346                    // Metal also requires that tmax >= 0.0, but if tmax >= tmin and tmin >= 0.0, tmax must be >= 0.0
347                    writeln!(self.out, "{inner_level}bool invalid_t = (desc.tmin > desc.tmax) || (desc.tmin < 0.0);")?;
348                    // Metal requires that the length of the direction is not 0.0. This is the case only when all the
349                    // components are zero.
350                    //
351                    // Use absolute to cover signed zero.
352                    writeln!(self.out, "{inner_level}bool invalid_dir = {NAMESPACE}::all({NAMESPACE}::abs(desc.dir) == 0.0);")?;
353
354                    writeln!(
355                        self.out,
356                        "{inner_level}if (!(invalid_dir || invalid_t || invalid_nan_infs)) {{"
357                    )?;
358                    init_level = init_level.next();
359                }
360
361                write!(self.out, "{init_level}")?;
362                // A ray query can by initialized in metal by either using a "non-default constructor"
363                // or by calling reset. Ray queries cannot be assigned to in metal, so reset needs to
364                // be called.
365                self.put_expression(query, &context.expression, true)?;
366                write!(self.out, ".reset(ray,")?;
367                self.put_expression(acceleration_structure, &context.expression, true)?;
368                writeln!(self.out, ", desc.cull_mask, params);")?;
369                if context.expression.ray_query_initialization_tracking {
370                    // We don't set the initialization tracker to zero (uninitialized)
371                    // if the call fails. Resetting to uninitialized might be useful
372                    // for debugging, but for everything else it is just extra code.
373                    writeln!(
374                        self.out,
375                        "{init_level}{tracker_expr_name} = {};",
376                        back::RayQueryPoint::INITIALIZED.bits()
377                    )?;
378                    writeln!(
379                        self.out,
380                        "{init_level}{tmax_tracker_expr_name} = desc.tmax;"
381                    )?;
382                    writeln!(self.out, "{inner_level}}}")?;
383                }
384                writeln!(self.out, "{level}}}")?;
385            }
386            crate::RayQueryFunction::Proceed { result } => {
387                let mut current_level = level;
388                write!(self.out, "{current_level}")?;
389                let name = Baked(result).to_string();
390                self.start_baking_expression(result, &context.expression, &name)?;
391                self.named_expressions.insert(result, name.clone());
392
393                writeln!(self.out, "false;")?;
394
395                if context.expression.ray_query_initialization_tracking {
396                    write!(self.out, "{level}if (")?;
397                    self.write_contains_flags(
398                        &tracker_expr_name,
399                        back::RayQueryPoint::INITIALIZED.bits(),
400                    )?;
401                    write!(self.out, " && !")?;
402                    self.write_contains_flags(
403                        &tracker_expr_name,
404                        back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
405                    )?;
406                    write!(self.out, ")")?;
407                    writeln!(self.out, " {{")?;
408                    current_level = current_level.next();
409                }
410                write!(self.out, "{current_level}{name} = ")?;
411                self.put_expression(query, &context.expression, true)?;
412                writeln!(self.out, ".next();")?;
413                if context.expression.ray_query_initialization_tracking {
414                    writeln!(self.out, "{current_level}{tracker_expr_name} = {tracker_expr_name} | ({name} ? {}: {});", back::RayQueryPoint::PROCEED.bits(), (back::RayQueryPoint::PROCEED | back::RayQueryPoint::FINISHED_TRAVERSAL).bits())?;
415                    writeln!(self.out, "{level}}}")?;
416                }
417            }
418            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
419                let mut current_level = level;
420                if context.expression.ray_query_initialization_tracking {
421                    write!(self.out, "{level}if (")?;
422                    self.write_contains_flags(
423                        &tracker_expr_name,
424                        back::RayQueryPoint::PROCEED.bits(),
425                    )?;
426                    write!(self.out, " && !")?;
427                    self.write_contains_flags(
428                        &tracker_expr_name,
429                        back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
430                    )?;
431                    write!(self.out, ")")?;
432                } else {
433                    // For readability
434                    write!(self.out, "{level}")?;
435                }
436                writeln!(self.out, "{{")?;
437                current_level = current_level.next();
438                write!(self.out, "{current_level}float t = ")?;
439                self.put_expression(hit_t, &context.expression, true)?;
440                writeln!(self.out, ";")?;
441                if context.expression.ray_query_initialization_tracking {
442                    write!(
443                        self.out,
444                        "{current_level}float current_max_t = {tmax_tracker_expr_name};
445{current_level}if ("
446                    )?;
447                    self.put_expression(query, &context.expression, true)?;
448                    write!(self.out, ".get_committed_intersection_type() != {RT_NAMESPACE}::intersection_type::none) {{
449{current_level}{INDENT}current_max_t = ")?;
450                    self.put_expression(query, &context.expression, true)?;
451                    write!(
452                        self.out,
453                        ".get_committed_distance();
454{current_level}}}
455{current_level}if ("
456                    )?;
457                    self.put_expression(query, &context.expression, true)?;
458                    write!(self.out, ".get_candidate_intersection_type() == {RT_NAMESPACE}::intersection_type::bounding_box && (")?;
459                    self.put_expression(query, &context.expression, true)?;
460                    write!(self.out, ".get_ray_min_distance()")?;
461                    writeln!(self.out, " <= t) && (t <= current_max_t)) {{")?;
462                    current_level = current_level.next();
463                }
464                write!(self.out, "{current_level}")?;
465                self.put_expression(query, &context.expression, true)?;
466                writeln!(self.out, ".commit_bounding_box_intersection(t);")?;
467                if context.expression.ray_query_initialization_tracking {
468                    writeln!(self.out, "{level}{INDENT}}}")?;
469                }
470                writeln!(self.out, "{level}}}")?;
471            }
472            crate::RayQueryFunction::ConfirmIntersection => {
473                let mut current_level = level;
474                if context.expression.ray_query_initialization_tracking {
475                    write!(self.out, "{level}if (")?;
476                    self.write_contains_flags(
477                        &tracker_expr_name,
478                        back::RayQueryPoint::PROCEED.bits(),
479                    )?;
480                    write!(self.out, " && !")?;
481                    self.write_contains_flags(
482                        &tracker_expr_name,
483                        back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
484                    )?;
485                    writeln!(self.out, ") {{")?;
486                    current_level = current_level.next();
487                    write!(self.out, "{current_level}if (")?;
488                    self.put_expression(query, &context.expression, true)?;
489                    writeln!(self.out, ".get_candidate_intersection_type() == {RT_NAMESPACE}::intersection_type::triangle) {{")?;
490                }
491                write!(self.out, "{level}")?;
492                self.put_expression(query, &context.expression, true)?;
493                writeln!(self.out, ".commit_triangle_intersection();")?;
494                if context.expression.ray_query_initialization_tracking {
495                    writeln!(
496                        self.out,
497                        "{level}{INDENT}}}
498{level}}}"
499                    )?;
500                }
501            }
502            crate::RayQueryFunction::Terminate => {
503                let mut current_level = level;
504                if context.expression.ray_query_initialization_tracking {
505                    write!(self.out, "{level}if (")?;
506                    self.write_contains_flags(
507                        &tracker_expr_name,
508                        back::RayQueryPoint::PROCEED.bits(),
509                    )?;
510                    write!(self.out, " && !")?;
511                    self.write_contains_flags(
512                        &tracker_expr_name,
513                        back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
514                    )?;
515                    writeln!(self.out, ") {{")?;
516                    current_level = current_level.next();
517                }
518                write!(self.out, "{current_level}")?;
519                self.put_expression(query, &context.expression, true)?;
520                // Terminate appears to map to abort in spirv-cross, but metal only documents
521                // the existence of this method, not what it does.
522                writeln!(self.out, ".abort();")?;
523                // To get the committed intersection, an extra proceed must occur as specified in
524                // the API docs.
525                if context.expression.ray_query_initialization_tracking {
526                    writeln!(self.out, "{level}}}")?;
527                }
528            }
529        }
530
531        Ok(())
532    }
533}