1use alloc::{
2 format,
3 string::{String, ToString},
4};
5use core::fmt::Write;
6
7use crate::{
8 back::{
9 self,
10 msl::{
11 writer::{StatementContext, TypeContext, WrappedFunction},
12 BackendResult, Error, Writer,
13 },
14 Baked,
15 },
16 Handle,
17};
18
19pub(super) const RT_NAMESPACE: &str = "metal::raytracing";
20
21pub(super) fn metal_intersector_ty() -> String {
23 format!("{RT_NAMESPACE}::intersection_query<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data>")
24}
25
26pub(super) const INTERSECTION_FUNCTION_NAME: &str = "ray_query_get_intersection";
27
28impl<W: Write> Writer<W> {
29 pub(super) fn write_rq_get_intersection_function(
35 &mut self,
36 module: &crate::Module,
37 committed: bool,
38 ) -> BackendResult {
39 let wrapped = WrappedFunction::RayQueryGetIntersection { committed };
40 if !self.wrapped_functions.insert(wrapped) {
41 return Ok(());
42 }
43
44 let ty = if committed { "committed" } else { "candidate" };
45 let intersection = TypeContext {
46 handle: module
47 .special_types
48 .ray_intersection
49 .expect("intersection ty should be there for intersection function"),
50 gctx: module.to_ctx(),
51 names: &self.names,
52 access: crate::StorageAccess::empty(),
53 first_time: false,
54 };
55 let level = back::Level(1);
56 writeln!(
57 self.out,
58 "{intersection} {INTERSECTION_FUNCTION_NAME}_{committed}({} intersector) {{",
59 metal_intersector_ty()
60 )?;
61 writeln!(
63 self.out,
64 "{level}{intersection} intersection = {intersection} {{}};"
65 )?;
66 writeln!(self.out, "{level}{RT_NAMESPACE}::intersection_type ty = intersector.get_{ty}_intersection_type();")?;
67 writeln!(
69 self.out,
70 "{level}if (ty == {RT_NAMESPACE}::intersection_type::triangle) {{"
71 )?;
72 writeln!(
73 self.out,
74 "{level}{level}intersection.kind = {};",
75 crate::RayQueryIntersection::Triangle as u32
76 )?;
77 if !committed {
78 writeln!(
79 self.out,
80 "{level}{level}intersection.t = intersector.get_candidate_triangle_distance();"
81 )?;
82 }
83 writeln!(self.out, "{level}{level}intersection.barycentrics = intersector.get_{ty}_triangle_barycentric_coord();")?;
84 writeln!(
85 self.out,
86 "{level}{level}intersection.front_face = intersector.is_{ty}_triangle_front_facing();"
87 )?;
88 writeln!(
91 self.out,
92 "{level}}} else if (ty == {RT_NAMESPACE}::intersection_type::bounding_box) {{"
93 )?;
94 if committed {
95 writeln!(
96 self.out,
97 "{level}{level}intersection.kind = {};",
98 crate::RayQueryIntersection::Generated as u32
99 )?;
100 } else {
101 writeln!(
102 self.out,
103 "{level}{level}intersection.kind = {};",
104 crate::RayQueryIntersection::Aabb as u32
105 )?;
106 }
107 writeln!(self.out, "{level}}}")?;
108
109 writeln!(
111 self.out,
112 "{level}if (ty != {RT_NAMESPACE}::intersection_type::none) {{"
113 )?;
114 if committed {
115 writeln!(
116 self.out,
117 "{level}{level}intersection.t = intersector.get_committed_distance();"
118 )?;
119 }
120 writeln!(self.out, "{level}{level}intersection.instance_custom_data = intersector.get_{ty}_user_instance_id();")?;
121 writeln!(
122 self.out,
123 "{level}{level}intersection.instance_index = intersector.get_{ty}_instance_id();"
124 )?;
125 writeln!(
128 self.out,
129 "{level}{level}intersection.geometry_index = intersector.get_{ty}_geometry_id();"
130 )?;
131 writeln!(
132 self.out,
133 "{level}{level}intersection.primitive_index = intersector.get_{ty}_primitive_id();"
134 )?;
135 writeln!(self.out, "{level}{level}intersection.object_to_world = intersector.get_{ty}_object_to_world_transform();")?;
136 writeln!(self.out, "{level}{level}intersection.world_to_object = intersector.get_{ty}_world_to_object_transform();")?;
137 writeln!(self.out, "{level}}}")?;
138 writeln!(self.out, "{level}return intersection;")?;
139 writeln!(self.out, "}}")?;
140
141 Ok(())
142 }
143
144 pub(super) fn write_ray_query_stmt(
145 &mut self,
146 level: back::Level,
147 context: &StatementContext,
148 query: Handle<crate::Expression>,
149 fun: &crate::RayQueryFunction,
150 ) -> BackendResult {
151 if context.expression.lang_version < (2, 4) {
152 return Err(Error::UnsupportedRayTracing);
153 }
154
155 match *fun {
157 crate::RayQueryFunction::Initialize {
158 acceleration_structure,
159 descriptor,
160 } => {
161 writeln!(self.out, "{level}{{")?;
166
167 let inner_level = level.next();
168
169 let naga_ray_desc_ty = TypeContext {
170 handle: context
171 .expression
172 .module
173 .special_types
174 .ray_desc
175 .expect("ray desc is required as an argument so should be there"),
176 gctx: context.expression.module.to_ctx(),
177 names: &self.names,
178 access: crate::StorageAccess::empty(),
179 first_time: false,
180 };
181
182 write!(self.out, "{inner_level}{naga_ray_desc_ty} desc = ")?;
183 self.put_expression(descriptor, &context.expression, false)?;
184 writeln!(self.out, ";")?;
185
186 writeln!(
188 self.out,
189 "{inner_level}{RT_NAMESPACE}::intersection_params params;"
190 )?;
191
192 {
193 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
195 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
196 writeln!(
197 self.out,
198 "{inner_level}params.set_opacity_cull_mode(
199{inner_level} (desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (
200{inner_level} (desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : {RT_NAMESPACE}::opacity_cull_mode::none
201{inner_level} )
202{inner_level});"
203 )?;
204 }
205 {
206 let f_opaque = back::RayFlag::OPAQUE.bits();
208 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
209 writeln!(self.out, "{inner_level}params.force_opacity(
210{inner_level} (desc.flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (
211{inner_level} (desc.flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : {RT_NAMESPACE}::forced_opacity::none
212{inner_level} )
213{inner_level});")?;
214 }
215 {
216 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
217 writeln!(
218 self.out,
219 "{inner_level}params.accept_any_intersection((desc.flags & {flag}) != 0);"
220 )?;
221 }
222
223 writeln!(
224 self.out,
225 "{inner_level}{RT_NAMESPACE}::ray ray = {RT_NAMESPACE}::ray(desc.origin, desc.dir, desc.tmin, desc.tmax);"
226 )?;
227
228 write!(self.out, "{inner_level}")?;
229 self.put_expression(query, &context.expression, true)?;
233 write!(self.out, ".reset(ray,")?;
234 self.put_expression(acceleration_structure, &context.expression, true)?;
235 writeln!(self.out, ", desc.cull_mask, params);")?;
236 writeln!(self.out, "{level}}}")?;
237 }
238 crate::RayQueryFunction::Proceed { result } => {
239 write!(self.out, "{level}")?;
240 let name = Baked(result).to_string();
241 self.start_baking_expression(result, &context.expression, &name)?;
242 self.named_expressions.insert(result, name);
243 self.put_expression(query, &context.expression, true)?;
244 writeln!(self.out, ".next();")?;
245 }
246 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
247 write!(self.out, "{level}")?;
248 self.put_expression(query, &context.expression, true)?;
249 write!(self.out, ".commit_bounding_box_intersection(")?;
250 self.put_expression(hit_t, &context.expression, true)?;
251 writeln!(self.out, ");")?;
252 }
253 crate::RayQueryFunction::ConfirmIntersection => {
254 write!(self.out, "{level}")?;
255 self.put_expression(query, &context.expression, true)?;
256 writeln!(self.out, ".commit_triangle_intersection();")?;
257 }
258 crate::RayQueryFunction::Terminate => {
259 write!(self.out, "{level}")?;
260 self.put_expression(query, &context.expression, true)?;
261 writeln!(self.out, ".abort();")?;
264 }
265 }
266
267 Ok(())
268 }
269}