naga/back/hlsl/ray.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::fmt::Write;
8
9use crate::{
10    back::{hlsl::BackendResult, Baked, Level},
11    Handle,
12};
13use crate::{RayQueryIntersection, TypeInner};
14
15impl<W: Write> super::Writer<'_, W> {
16    // https://sakibsaikia.github.io/graphics/2022/01/04/Nan-Checks-In-HLSL.html suggests that isnan may not work, unsure if this has changed.
17    fn write_not_finite(&mut self, expr: &str) -> BackendResult {
18        self.write_contains_flags(&format!("asuint({expr})"), 0x7f800000)
19    }
20
21    fn write_nan(&mut self, expr: &str) -> BackendResult {
22        write!(self.out, "(")?;
23        self.write_not_finite(expr)?;
24        write!(self.out, " && ((asuint({expr}) & 0x7fffff) != 0))")?;
25        Ok(())
26    }
27
28    fn write_contains_flags(&mut self, expr: &str, flags: u32) -> BackendResult {
29        write!(self.out, "(({expr} & {flags}) == {flags})")?;
30        Ok(())
31    }
32
33    // constructs hlsl RayDesc from wgsl RayDesc
34    pub(super) fn write_ray_desc_from_ray_desc_constructor_function(
35        &mut self,
36        module: &crate::Module,
37    ) -> BackendResult {
38        write!(self.out, "RayDesc RayDescFromRayDesc_(")?;
39        self.write_type(module, module.special_types.ray_desc.unwrap())?;
40        writeln!(self.out, " arg0) {{")?;
41        writeln!(self.out, "    RayDesc ret = (RayDesc)0;")?;
42        writeln!(self.out, "    ret.Origin = arg0.origin;")?;
43        writeln!(self.out, "    ret.TMin = arg0.tmin;")?;
44        writeln!(self.out, "    ret.Direction = arg0.dir;")?;
45        writeln!(self.out, "    ret.TMax = arg0.tmax;")?;
46        writeln!(self.out, "    return ret;")?;
47        writeln!(self.out, "}}")?;
48        writeln!(self.out)?;
49        Ok(())
50    }
51    pub(super) fn write_committed_intersection_function(
52        &mut self,
53        module: &crate::Module,
54    ) -> BackendResult {
55        self.write_type(module, module.special_types.ray_intersection.unwrap())?;
56        write!(self.out, " GetCommittedIntersection(")?;
57        self.write_value_type(
58            module,
59            &TypeInner::RayQuery {
60                vertex_return: false,
61            },
62        )?;
63        write!(self.out, " rq, ")?;
64        self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?;
65        writeln!(self.out, " rq_tracker) {{")?;
66        write!(self.out, "    ")?;
67        self.write_type(module, module.special_types.ray_intersection.unwrap())?;
68        write!(self.out, " ret = (")?;
69        self.write_type(module, module.special_types.ray_intersection.unwrap())?;
70        writeln!(self.out, ")0;")?;
71        let mut extra_level = Level(0);
72        if self.options.ray_query_initialization_tracking {
73            // *Technically*, `CommittedStatus` is valid as long as the ray query is initialized, but the metal backend
74            // doesn't support this function unless it has finished traversal, so to encourage portable behaviour we
75            // disallow it here too.
76            write!(self.out, "    if (")?;
77            self.write_contains_flags(
78                "rq_tracker",
79                crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
80            )?;
81            writeln!(self.out, ") {{")?;
82            extra_level = extra_level.next();
83        }
84        writeln!(
85            self.out,
86            "    {extra_level}ret.kind = rq.CommittedStatus();"
87        )?;
88        writeln!(
89            self.out,
90            "    {extra_level}if( rq.CommittedStatus() == COMMITTED_NOTHING) {{}} else {{"
91        )?;
92        writeln!(self.out, "        {extra_level}ret.t = rq.CommittedRayT();")?;
93        writeln!(
94            self.out,
95            "        {extra_level}ret.instance_custom_data = rq.CommittedInstanceID();"
96        )?;
97        writeln!(
98            self.out,
99            "        {extra_level}ret.instance_index = rq.CommittedInstanceIndex();"
100        )?;
101        writeln!(
102            self.out,
103            "        {extra_level}ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex();"
104        )?;
105        writeln!(
106            self.out,
107            "        {extra_level}ret.geometry_index = rq.CommittedGeometryIndex();"
108        )?;
109        writeln!(
110            self.out,
111            "        {extra_level}ret.primitive_index = rq.CommittedPrimitiveIndex();"
112        )?;
113        writeln!(
114            self.out,
115            "        {extra_level}if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) {{"
116        )?;
117        writeln!(
118            self.out,
119            "            {extra_level}ret.barycentrics = rq.CommittedTriangleBarycentrics();"
120        )?;
121        writeln!(
122            self.out,
123            "            {extra_level}ret.front_face = rq.CommittedTriangleFrontFace();"
124        )?;
125        writeln!(self.out, "        {extra_level}}}")?;
126        writeln!(
127            self.out,
128            "        {extra_level}ret.object_to_world = rq.CommittedObjectToWorld4x3();"
129        )?;
130        writeln!(
131            self.out,
132            "        {extra_level}ret.world_to_object = rq.CommittedWorldToObject4x3();"
133        )?;
134        writeln!(self.out, "    {extra_level}}}")?;
135        if self.options.ray_query_initialization_tracking {
136            writeln!(self.out, "    }}")?;
137        }
138        writeln!(self.out, "    return ret;")?;
139        writeln!(self.out, "}}")?;
140        writeln!(self.out)?;
141        Ok(())
142    }
143    pub(super) fn write_candidate_intersection_function(
144        &mut self,
145        module: &crate::Module,
146    ) -> BackendResult {
147        self.write_type(module, module.special_types.ray_intersection.unwrap())?;
148        write!(self.out, " GetCandidateIntersection(")?;
149        self.write_value_type(
150            module,
151            &TypeInner::RayQuery {
152                vertex_return: false,
153            },
154        )?;
155        write!(self.out, " rq, ")?;
156        self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?;
157        writeln!(self.out, " rq_tracker) {{")?;
158        write!(self.out, "    ")?;
159        self.write_type(module, module.special_types.ray_intersection.unwrap())?;
160        write!(self.out, " ret = (")?;
161        self.write_type(module, module.special_types.ray_intersection.unwrap())?;
162        writeln!(self.out, ")0;")?;
163        let mut extra_level = Level(0);
164        if self.options.ray_query_initialization_tracking {
165            write!(self.out, "    if (")?;
166            self.write_contains_flags("rq_tracker", crate::back::RayQueryPoint::PROCEED.bits())?;
167            write!(self.out, " && !")?;
168            self.write_contains_flags(
169                "rq_tracker",
170                crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
171            )?;
172            writeln!(self.out, ") {{")?;
173            extra_level = extra_level.next();
174        }
175        writeln!(
176            self.out,
177            "    {extra_level}CANDIDATE_TYPE kind = rq.CandidateType();"
178        )?;
179        writeln!(
180            self.out,
181            "    {extra_level}if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{"
182        )?;
183        writeln!(
184            self.out,
185            "        {extra_level}ret.kind = {};",
186            RayQueryIntersection::Triangle as u32
187        )?;
188        writeln!(
189            self.out,
190            "        {extra_level}ret.t = rq.CandidateTriangleRayT();"
191        )?;
192        writeln!(
193            self.out,
194            "        {extra_level}ret.barycentrics = rq.CandidateTriangleBarycentrics();"
195        )?;
196        writeln!(
197            self.out,
198            "        {extra_level}ret.front_face = rq.CandidateTriangleFrontFace();"
199        )?;
200        writeln!(self.out, "    {extra_level}}} else {{")?;
201        writeln!(
202            self.out,
203            "        {extra_level}ret.kind = {};",
204            RayQueryIntersection::Aabb as u32
205        )?;
206        writeln!(self.out, "    {extra_level}}}")?;
207
208        writeln!(
209            self.out,
210            "    {extra_level}ret.instance_custom_data = rq.CandidateInstanceID();"
211        )?;
212        writeln!(
213            self.out,
214            "    {extra_level}ret.instance_index = rq.CandidateInstanceIndex();"
215        )?;
216        writeln!(
217            self.out,
218            "    {extra_level}ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();"
219        )?;
220        writeln!(
221            self.out,
222            "    {extra_level}ret.geometry_index = rq.CandidateGeometryIndex();"
223        )?;
224        writeln!(
225            self.out,
226            "    {extra_level}ret.primitive_index = rq.CandidatePrimitiveIndex();"
227        )?;
228        writeln!(
229            self.out,
230            "    {extra_level}ret.object_to_world = rq.CandidateObjectToWorld4x3();"
231        )?;
232        writeln!(
233            self.out,
234            "    {extra_level}ret.world_to_object = rq.CandidateWorldToObject4x3();"
235        )?;
236        if self.options.ray_query_initialization_tracking {
237            writeln!(self.out, "    }}")?;
238        }
239        writeln!(self.out, "    return ret;")?;
240        writeln!(self.out, "}}")?;
241        writeln!(self.out)?;
242        Ok(())
243    }
244
245    #[expect(clippy::too_many_arguments)]
246    pub(super) fn write_initialize_function(
247        &mut self,
248        module: &crate::Module,
249        mut level: Level,
250        query: Handle<crate::Expression>,
251        acceleration_structure: Handle<crate::Expression>,
252        descriptor: Handle<crate::Expression>,
253        rq_tracker: &str,
254        func_ctx: &crate::back::FunctionCtx<'_>,
255    ) -> BackendResult {
256        let base_level = level;
257
258        // This prevents variables flowing down a level and causing compile errors.
259        writeln!(self.out, "{level}{{")?;
260        level = level.next();
261        write!(self.out, "{level}")?;
262        self.write_type(
263            module,
264            module
265                .special_types
266                .ray_desc
267                .expect("should have been generated"),
268        )?;
269        write!(self.out, " naga_desc = ")?;
270        self.write_expr(module, descriptor, func_ctx)?;
271        writeln!(self.out, ";")?;
272
273        if self.options.ray_query_initialization_tracking {
274            // Validate ray extents https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#ray-extents
275
276            // just for convenience
277            writeln!(self.out, "{level}float naga_tmin = naga_desc.tmin;")?;
278            writeln!(self.out, "{level}float naga_tmax = naga_desc.tmax;")?;
279            writeln!(self.out, "{level}float3 naga_origin = naga_desc.origin;")?;
280            writeln!(self.out, "{level}float3 naga_dir = naga_desc.dir;")?;
281            writeln!(self.out, "{level}uint naga_flags = naga_desc.flags;")?;
282            write!(
283                self.out,
284                "{level}bool naga_tmin_valid = (naga_tmin >= 0.0) && (naga_tmin <= naga_tmax) && !"
285            )?;
286            self.write_nan("naga_tmin")?;
287            writeln!(self.out, ";")?;
288            write!(self.out, "{level}bool naga_tmax_valid = !")?;
289            self.write_nan("naga_tmax")?;
290            writeln!(self.out, ";")?;
291            // Unlike Vulkan it seems that for DX12, it seems only NaN components of the origin and direction are invalid
292            write!(self.out, "{level}bool naga_origin_valid = !any(")?;
293            self.write_nan("naga_origin")?;
294            writeln!(self.out, ");")?;
295            write!(self.out, "{level}bool naga_dir_valid = !any(")?;
296            self.write_nan("naga_dir")?;
297            writeln!(self.out, ");")?;
298            write!(self.out, "{level}bool naga_contains_opaque = ")?;
299            self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_OPAQUE.bits())?;
300            writeln!(self.out, ";")?;
301            write!(self.out, "{level}bool naga_contains_no_opaque = ")?;
302            self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_NO_OPAQUE.bits())?;
303            writeln!(self.out, ";")?;
304            write!(self.out, "{level}bool naga_contains_cull_opaque = ")?;
305            self.write_contains_flags("naga_flags", crate::RayFlag::CULL_OPAQUE.bits())?;
306            writeln!(self.out, ";")?;
307            write!(self.out, "{level}bool naga_contains_cull_no_opaque = ")?;
308            self.write_contains_flags("naga_flags", crate::RayFlag::CULL_NO_OPAQUE.bits())?;
309            writeln!(self.out, ";")?;
310            write!(self.out, "{level}bool naga_contains_cull_front = ")?;
311            self.write_contains_flags("naga_flags", crate::RayFlag::CULL_FRONT_FACING.bits())?;
312            writeln!(self.out, ";")?;
313            write!(self.out, "{level}bool naga_contains_cull_back = ")?;
314            self.write_contains_flags("naga_flags", crate::RayFlag::CULL_BACK_FACING.bits())?;
315            writeln!(self.out, ";")?;
316            write!(self.out, "{level}bool naga_contains_skip_triangles = ")?;
317            self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_TRIANGLES.bits())?;
318            writeln!(self.out, ";")?;
319            write!(self.out, "{level}bool naga_contains_skip_aabbs = ")?;
320            self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_AABBS.bits())?;
321            writeln!(self.out, ";")?;
322            // A textified version of the same in the spirv writer
323            fn less_than_two_true(mut bools: Vec<&str>) -> Result<String, super::Error> {
324                assert!(bools.len() > 1, "Must have multiple booleans!");
325                let mut final_expr = String::new();
326                while let Some(last_bool) = bools.pop() {
327                    for &bool in &bools {
328                        if !final_expr.is_empty() {
329                            final_expr.push_str("||");
330                        }
331                        write!(final_expr, " ({last_bool} && {bool}) ")?;
332                    }
333                }
334                Ok(final_expr)
335            }
336            writeln!(
337                self.out,
338                "{level}bool naga_contains_skip_triangles_aabbs = {};",
339                less_than_two_true(vec![
340                    "naga_contains_skip_triangles",
341                    "naga_contains_skip_aabbs"
342                ])?
343            )?;
344            writeln!(
345                self.out,
346                "{level}bool naga_contains_skip_triangles_cull = {};",
347                less_than_two_true(vec![
348                    "naga_contains_skip_triangles",
349                    "naga_contains_cull_back",
350                    "naga_contains_cull_front"
351                ])?
352            )?;
353            writeln!(
354                self.out,
355                "{level}bool naga_contains_multiple_opaque = {};",
356                less_than_two_true(vec![
357                    "naga_contains_opaque",
358                    "naga_contains_no_opaque",
359                    "naga_contains_cull_opaque",
360                    "naga_contains_cull_no_opaque"
361                ])?
362            )?;
363            writeln!(
364                self.out,
365                "{level}if (naga_tmin_valid && naga_tmax_valid && naga_origin_valid && naga_dir_valid && !(naga_contains_skip_triangles_aabbs || naga_contains_skip_triangles_cull || naga_contains_multiple_opaque)) {{"
366            )?;
367            level = level.next();
368            writeln!(
369                self.out,
370                "{level}{rq_tracker} = {rq_tracker} | {};",
371                crate::back::RayQueryPoint::INITIALIZED.bits()
372            )?;
373        }
374        write!(self.out, "{level}")?;
375        self.write_expr(module, query, func_ctx)?;
376        write!(self.out, ".TraceRayInline(")?;
377        self.write_expr(module, acceleration_structure, func_ctx)?;
378        writeln!(
379            self.out,
380            ", naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc));"
381        )?;
382        if self.options.ray_query_initialization_tracking {
383            writeln!(self.out, "{base_level}    }}")?;
384        }
385        writeln!(self.out, "{base_level}}}")?;
386        Ok(())
387    }
388
389    pub(super) fn write_proceed(
390        &mut self,
391        module: &crate::Module,
392        mut level: Level,
393        query: Handle<crate::Expression>,
394        result: Handle<crate::Expression>,
395        rq_tracker: &str,
396        func_ctx: &crate::back::FunctionCtx<'_>,
397    ) -> BackendResult {
398        let base_level = level;
399        write!(self.out, "{level}")?;
400        let name = Baked(result).to_string();
401        writeln!(self.out, "bool {name} = false;")?;
402        // This prevents variables flowing down a level and causing compile errors.
403        if self.options.ray_query_initialization_tracking {
404            writeln!(self.out, "{level}{{")?;
405            level = level.next();
406            write!(self.out, "{level}bool naga_has_initialized = ")?;
407            self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?;
408            writeln!(self.out, ";")?;
409            write!(self.out, "{level}bool naga_has_finished = ")?;
410            self.write_contains_flags(
411                rq_tracker,
412                crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
413            )?;
414            writeln!(self.out, ";")?;
415            writeln!(
416                self.out,
417                "{level}if (naga_has_initialized && !naga_has_finished) {{"
418            )?;
419            level = level.next();
420        }
421
422        write!(self.out, "{level}{name} = ")?;
423        self.write_expr(module, query, func_ctx)?;
424        writeln!(self.out, ".Proceed();")?;
425
426        if self.options.ray_query_initialization_tracking {
427            writeln!(
428                self.out,
429                "{level}{rq_tracker} = {rq_tracker} | {};",
430                crate::back::RayQueryPoint::PROCEED.bits()
431            )?;
432            writeln!(
433                self.out,
434                "{level}if (!{name}) {{ {rq_tracker} = {rq_tracker} | {}; }}",
435                crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits()
436            )?;
437            writeln!(self.out, "{base_level}}}}}")?;
438        }
439
440        self.named_expressions.insert(result, name);
441
442        Ok(())
443    }
444
445    pub(super) fn write_generate_intersection(
446        &mut self,
447        module: &crate::Module,
448        mut level: Level,
449        query: Handle<crate::Expression>,
450        hit_t: Handle<crate::Expression>,
451        rq_tracker: &str,
452        func_ctx: &crate::back::FunctionCtx<'_>,
453    ) -> BackendResult {
454        let base_level = level;
455        if self.options.ray_query_initialization_tracking {
456            write!(self.out, "{level}if (")?;
457            self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?;
458            write!(self.out, " && !")?;
459            self.write_contains_flags(
460                rq_tracker,
461                crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
462            )?;
463            writeln!(self.out, ") {{")?;
464            level = level.next();
465            write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?;
466            self.write_expr(module, query, func_ctx)?;
467            writeln!(self.out, ".CandidateType();")?;
468            write!(self.out, "{level}float naga_tmin = ")?;
469            self.write_expr(module, query, func_ctx)?;
470            writeln!(self.out, ".RayTMin();")?;
471            write!(self.out, "{level}float naga_tcurrentmax = ")?;
472            self.write_expr(module, query, func_ctx)?;
473            // This gets initialized to tmax and is updated after each intersection is committed so is valid to call.
474            // Note: there is a bug in DXC's spirv backend that makes this technically UB in spirv, but HLSL backend
475            // is intended for DXIL, so it should be fine (hopefully).
476            writeln!(self.out, ".CommittedRayT();")?;
477            write!(
478                self.out,
479                "{level}if ((naga_kind == CANDIDATE_PROCEDURAL_PRIMITIVE) && (naga_tmin <="
480            )?;
481            self.write_expr(module, hit_t, func_ctx)?;
482            write!(self.out, ") && (")?;
483            self.write_expr(module, hit_t, func_ctx)?;
484            writeln!(self.out, " <= naga_tcurrentmax)) {{")?;
485            level = level.next();
486        }
487
488        write!(self.out, "{level}")?;
489        self.write_expr(module, query, func_ctx)?;
490        write!(self.out, ".CommitProceduralPrimitiveHit(")?;
491        self.write_expr(module, hit_t, func_ctx)?;
492        writeln!(self.out, ");")?;
493        if self.options.ray_query_initialization_tracking {
494            writeln!(self.out, "{base_level}}}}}")?;
495        }
496        Ok(())
497    }
498    pub(super) fn write_confirm_intersection(
499        &mut self,
500        module: &crate::Module,
501        mut level: Level,
502        query: Handle<crate::Expression>,
503        rq_tracker: &str,
504        func_ctx: &crate::back::FunctionCtx<'_>,
505    ) -> BackendResult {
506        let base_level = level;
507        if self.options.ray_query_initialization_tracking {
508            write!(self.out, "{level}if (")?;
509            self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?;
510            write!(self.out, " && !")?;
511            self.write_contains_flags(
512                rq_tracker,
513                crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
514            )?;
515            writeln!(self.out, ") {{")?;
516            level = level.next();
517            write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?;
518            self.write_expr(module, query, func_ctx)?;
519            writeln!(self.out, ".CandidateType();")?;
520            writeln!(
521                self.out,
522                "{level}if (naga_kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{"
523            )?;
524            level = level.next();
525        }
526
527        write!(self.out, "{level}")?;
528        self.write_expr(module, query, func_ctx)?;
529        writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
530        if self.options.ray_query_initialization_tracking {
531            writeln!(self.out, "{base_level}}}}}")?;
532        }
533        Ok(())
534    }
535
536    pub(super) fn write_terminate(
537        &mut self,
538        module: &crate::Module,
539        mut level: Level,
540        query: Handle<crate::Expression>,
541        rq_tracker: &str,
542        func_ctx: &crate::back::FunctionCtx<'_>,
543    ) -> BackendResult {
544        let base_level = level;
545        if self.options.ray_query_initialization_tracking {
546            write!(self.out, "{level}if (")?;
547            // RayQuery::Abort() can be called any time after RayQuery::TraceRayInline() has been called.
548            // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#rayquery-abort
549            self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?;
550            writeln!(self.out, ") {{")?;
551            level = level.next();
552        }
553
554        write!(self.out, "{level}")?;
555        self.write_expr(module, query, func_ctx)?;
556        writeln!(self.out, ".Abort();")?;
557
558        if self.options.ray_query_initialization_tracking {
559            writeln!(self.out, "{base_level}}}")?;
560        }
561
562        Ok(())
563    }
564}