// naga/back/msl/ray.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4};
5use core::fmt::Write;
6
7use crate::{
8    back::{
9        self,
10        msl::{
11            writer::{StatementContext, TypeContext, WrappedFunction},
12            BackendResult, Error, Writer,
13        },
14        Baked,
15    },
16    Handle,
17};
18
19pub(super) const RT_NAMESPACE: &str = "metal::raytracing";
20
21/// The ray query type, needs to be a function so it can format the constants.
22pub(super) fn metal_intersector_ty() -> String {
23    format!("{RT_NAMESPACE}::intersection_query<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data>")
24}
25
26pub(super) const INTERSECTION_FUNCTION_NAME: &str = "ray_query_get_intersection";
27
28impl<W: Write> Writer<W> {
29    /// Writes a function to get the current intersection from the ray query
30    ///
31    /// Like other backends, this is needed to have a single branch for constructing
32    /// the parts of the intersection that need to be checked whether they do or don't
33    /// hit.
34    pub(super) fn write_rq_get_intersection_function(
35        &mut self,
36        module: &crate::Module,
37        committed: bool,
38    ) -> BackendResult {
39        let wrapped = WrappedFunction::RayQueryGetIntersection { committed };
40        if !self.wrapped_functions.insert(wrapped) {
41            return Ok(());
42        }
43
44        let ty = if committed { "committed" } else { "candidate" };
45        let intersection = TypeContext {
46            handle: module
47                .special_types
48                .ray_intersection
49                .expect("intersection ty should be there for intersection function"),
50            gctx: module.to_ctx(),
51            names: &self.names,
52            access: crate::StorageAccess::empty(),
53            first_time: false,
54        };
55        let level = back::Level(1);
56        writeln!(
57            self.out,
58            "{intersection} {INTERSECTION_FUNCTION_NAME}_{committed}({} intersector) {{",
59            metal_intersector_ty()
60        )?;
61        // Initialize the intersection to its default values (which should be zero).
62        writeln!(
63            self.out,
64            "{level}{intersection} intersection = {intersection} {{}};"
65        )?;
66        writeln!(self.out, "{level}{RT_NAMESPACE}::intersection_type ty = intersector.get_{ty}_intersection_type();")?;
67        // If the ray hit a triangle, call all methods that require that and set the intersection type.
68        writeln!(
69            self.out,
70            "{level}if (ty == {RT_NAMESPACE}::intersection_type::triangle) {{"
71        )?;
72        writeln!(
73            self.out,
74            "{level}{level}intersection.kind = {};",
75            crate::RayQueryIntersection::Triangle as u32
76        )?;
77        if !committed {
78            writeln!(
79                self.out,
80                "{level}{level}intersection.t = intersector.get_candidate_triangle_distance();"
81            )?;
82        }
83        writeln!(self.out, "{level}{level}intersection.barycentrics = intersector.get_{ty}_triangle_barycentric_coord();")?;
84        writeln!(
85            self.out,
86            "{level}{level}intersection.front_face = intersector.is_{ty}_triangle_front_facing();"
87        )?;
88        // Otherwise, if the ray hit an AABB (called a bounding box in metal) set the intersection type
89        // (which depends on whether this is a committed or candidate intersection).
90        writeln!(
91            self.out,
92            "{level}}} else if (ty == {RT_NAMESPACE}::intersection_type::bounding_box) {{"
93        )?;
94        if committed {
95            writeln!(
96                self.out,
97                "{level}{level}intersection.kind = {};",
98                crate::RayQueryIntersection::Generated as u32
99            )?;
100        } else {
101            writeln!(
102                self.out,
103                "{level}{level}intersection.kind = {};",
104                crate::RayQueryIntersection::Aabb as u32
105            )?;
106        }
107        writeln!(self.out, "{level}}}")?;
108
109        // If the ray hit anything at all, call all methods that require that.
110        writeln!(
111            self.out,
112            "{level}if (ty != {RT_NAMESPACE}::intersection_type::none) {{"
113        )?;
114        if committed {
115            writeln!(
116                self.out,
117                "{level}{level}intersection.t = intersector.get_committed_distance();"
118            )?;
119        }
120        writeln!(self.out, "{level}{level}intersection.instance_custom_data = intersector.get_{ty}_user_instance_id();")?;
121        writeln!(
122            self.out,
123            "{level}{level}intersection.instance_index = intersector.get_{ty}_instance_id();"
124        )?;
125        // Metal does not appear to support obtaining the intersection offset from a ray query.
126        //writeln!(self.out, "{level}{level}intersection.sbt_record_offset = intersector.get_{ty}_user_instance_id();")?;
127        writeln!(
128            self.out,
129            "{level}{level}intersection.geometry_index = intersector.get_{ty}_geometry_id();"
130        )?;
131        writeln!(
132            self.out,
133            "{level}{level}intersection.primitive_index = intersector.get_{ty}_primitive_id();"
134        )?;
135        writeln!(self.out, "{level}{level}intersection.object_to_world = intersector.get_{ty}_object_to_world_transform();")?;
136        writeln!(self.out, "{level}{level}intersection.world_to_object = intersector.get_{ty}_world_to_object_transform();")?;
137        writeln!(self.out, "{level}}}")?;
138        writeln!(self.out, "{level}return intersection;")?;
139        writeln!(self.out, "}}")?;
140
141        Ok(())
142    }
143
144    pub(super) fn write_ray_query_stmt(
145        &mut self,
146        level: back::Level,
147        context: &StatementContext,
148        query: Handle<crate::Expression>,
149        fun: &crate::RayQueryFunction,
150    ) -> BackendResult {
151        if context.expression.lang_version < (2, 4) {
152            return Err(Error::UnsupportedRayTracing);
153        }
154
155        // TODO: check for misuse.
156        match *fun {
157            crate::RayQueryFunction::Initialize {
158                acceleration_structure,
159                descriptor,
160            } => {
161                //TODO: how to deal with winding? Is it by default the same as the other APIs?
162
163                // Put everything in a block so that the variable names
164                // do not conflict with user variable names
165                writeln!(self.out, "{level}{{")?;
166
167                let inner_level = level.next();
168
169                let naga_ray_desc_ty = TypeContext {
170                    handle: context
171                        .expression
172                        .module
173                        .special_types
174                        .ray_desc
175                        .expect("ray desc is required as an argument so should be there"),
176                    gctx: context.expression.module.to_ctx(),
177                    names: &self.names,
178                    access: crate::StorageAccess::empty(),
179                    first_time: false,
180                };
181
182                write!(self.out, "{inner_level}{naga_ray_desc_ty} desc = ")?;
183                self.put_expression(descriptor, &context.expression, false)?;
184                writeln!(self.out, ";")?;
185
186                // Set up intersection parameters
187                writeln!(
188                    self.out,
189                    "{inner_level}{RT_NAMESPACE}::intersection_params params;"
190                )?;
191
192                {
193                    // Determine whether or not to cull opaque/non-opaques
194                    let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
195                    let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
196                    writeln!(
197                        self.out,
198                        "{inner_level}params.set_opacity_cull_mode(
199{inner_level}    (desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (
200{inner_level}        (desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : {RT_NAMESPACE}::opacity_cull_mode::none
201{inner_level}    )
202{inner_level});"
203                    )?;
204                }
205                {
206                    // Determine whether to force a particular opacity
207                    let f_opaque = back::RayFlag::OPAQUE.bits();
208                    let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
209                    writeln!(self.out, "{inner_level}params.force_opacity(
210{inner_level}    (desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (
211{inner_level}        (desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : {RT_NAMESPACE}::forced_opacity::none
212{inner_level}    )
213{inner_level});")?;
214                }
215                {
216                    let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
217                    writeln!(
218                        self.out,
219                        "{inner_level}params.accept_any_intersection((desc.flags & {flag}) != 0);"
220                    )?;
221                }
222
223                writeln!(
224                    self.out,
225                    "{inner_level}{RT_NAMESPACE}::ray ray = {RT_NAMESPACE}::ray(desc.origin, desc.dir, desc.tmin, desc.tmax);"
226                )?;
227
228                write!(self.out, "{inner_level}")?;
229                // A ray query can by initialized in metal by either using a "non-default constructor"
230                // or by calling reset. Ray queries cannot be assigned to in metal, so reset needs to
231                // be called.
232                self.put_expression(query, &context.expression, true)?;
233                write!(self.out, ".reset(ray,")?;
234                self.put_expression(acceleration_structure, &context.expression, true)?;
235                writeln!(self.out, ", desc.cull_mask, params);")?;
236                writeln!(self.out, "{level}}}")?;
237            }
238            crate::RayQueryFunction::Proceed { result } => {
239                write!(self.out, "{level}")?;
240                let name = Baked(result).to_string();
241                self.start_baking_expression(result, &context.expression, &name)?;
242                self.named_expressions.insert(result, name);
243                self.put_expression(query, &context.expression, true)?;
244                writeln!(self.out, ".next();")?;
245            }
246            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
247                write!(self.out, "{level}")?;
248                self.put_expression(query, &context.expression, true)?;
249                write!(self.out, ".commit_bounding_box_intersection(")?;
250                self.put_expression(hit_t, &context.expression, true)?;
251                writeln!(self.out, ");")?;
252            }
253            crate::RayQueryFunction::ConfirmIntersection => {
254                write!(self.out, "{level}")?;
255                self.put_expression(query, &context.expression, true)?;
256                writeln!(self.out, ".commit_triangle_intersection();")?;
257            }
258            crate::RayQueryFunction::Terminate => {
259                write!(self.out, "{level}")?;
260                self.put_expression(query, &context.expression, true)?;
261                // Terminate appears to map to abort in spirv-cross, but metal only documents
262                // the existence of this method, not what it does.
263                writeln!(self.out, ".abort();")?;
264            }
265        }
266
267        Ok(())
268    }
269}