naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::{self, Write as _},
8    mem,
9};
10
11use super::{
12    help,
13    help::{
14        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15        WrappedZeroValue,
16    },
17    storage::StoreValue,
18    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21    back::{self, get_entry_points, Baked},
22    common,
23    proc::{self, index, ExternalTextureNameKey, NameKey},
24    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Semantic prefix for location-bound varyings; the location number is
/// appended to it (see `write_semantic`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type emitted for the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable declared with [`SPECIAL_CBUF_TYPE`].
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Name of the first `int` member of the special constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Name of the second `int` member of the special constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Name of the `uint` member of the special constants struct.
const SPECIAL_OTHER: &str = "other";

// Identifiers shared with the rest of the crate.
// NOTE(review): these look like the names of helper functions and variables
// emitted into the generated HLSL; their uses are outside this section, so
// confirm against the call sites before relying on that reading.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix for variables in a naga statement
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
56enum Index {
57    Expression(Handle<crate::Expression>),
58    Static(u32),
59}
60
61pub(super) struct EpStructMember {
62    pub(super) name: String,
63    pub(super) ty: Handle<crate::Type>,
64    // technically, this should always be `Some`
65    // (we `debug_assert!` this in `write_interface_struct`)
66    pub(super) binding: Option<crate::Binding>,
67    pub(super) index: u32,
68}
69
70/// Structure contains information required for generating
71/// wrapped structure of all entry points arguments
72pub(super) struct EntryPointBinding {
73    /// Name of the fake EP argument that contains the struct
74    /// with all the flattened input data.
75    pub(super) arg_name: String,
76    /// Generated structure name
77    pub(super) ty_name: String,
78    /// Members of generated structure
79    pub(super) members: Vec<EpStructMember>,
80    pub(super) local_invocation_index_name: Option<String>,
81}
82
83pub(super) struct EntryPointInterface {
84    /// If `Some`, the input of an entry point is gathered in a special
85    /// struct with members sorted by binding.
86    /// The `EntryPointBinding::members` array is sorted by index,
87    /// so that we can walk it in `write_ep_arguments_initialization`.
88    pub(crate) input: Option<EntryPointBinding>,
89    /// If `Some`, the output of an entry point is flattened.
90    /// The `EntryPointBinding::members` array is sorted by binding,
91    /// So that we can walk it in `Statement::Return` handler.
92    pub(crate) output: Option<EntryPointBinding>,
93    pub(crate) mesh_vertices: Option<EntryPointBinding>,
94    pub(crate) mesh_primitives: Option<EntryPointBinding>,
95    pub(crate) mesh_indices: Option<EntryPointBinding>,
96}
97
98#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
99enum InterfaceKey {
100    Location(u32),
101    BuiltIn(crate::BuiltIn),
102    Other,
103}
104
105impl InterfaceKey {
106    const fn new(binding: Option<&crate::Binding>) -> Self {
107        match binding {
108            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110            None => Self::Other,
111        }
112    }
113}
114
115#[derive(Copy, Clone, PartialEq)]
116pub(super) enum Io {
117    Input,
118    Output,
119    MeshVertices,
120    MeshPrimitives,
121}
122
123/// Argument list for nested entry points
124pub(super) struct NestedEntryPointArgs {
125    /// Arguments literally declared by the user
126    pub user_args: Vec<String>,
127    pub task_payload: Option<String>,
128    pub local_invocation_index: String,
129}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133        return false;
134    };
135    matches!(
136        builtin,
137        crate::BuiltIn::SubgroupSize
138            | crate::BuiltIn::SubgroupInvocationId
139            | crate::BuiltIn::NumSubgroups
140            | crate::BuiltIn::SubgroupId
141    )
142}
143
/// Names needed to generate an access into a `binding_array<sampler>`.
struct BindingArraySamplerInfo {
    /// Name of the sampler heap variable.
    sampler_heap_name: &'static str,
    /// Name of the sampler index buffer variable.
    sampler_index_buffer_name: String,
    /// Name of the variable holding the base index _into_ the sampler index
    /// buffer.
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156        Self {
157            out,
158            names: crate::FastHashMap::default(),
159            namer: proc::Namer::default(),
160            options,
161            pipeline_options,
162            entry_point_io: crate::FastHashMap::default(),
163            named_expressions: crate::NamedExpressions::default(),
164            wrapped: super::Wrapped::default(),
165            written_committed_intersection: false,
166            written_candidate_intersection: false,
167            continue_ctx: back::continue_forward::ContinueCtx::default(),
168            temp_access_chain: Vec::new(),
169            need_bake_expressions: Default::default(),
170            function_task_payload_var: Default::default(),
171        }
172    }
173
174    fn reset(&mut self, module: &Module) {
175        self.names.clear();
176        self.namer.reset(
177            module,
178            &super::keywords::RESERVED_SET,
179            proc::KeywordSet::empty(),
180            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181            super::keywords::RESERVED_PREFIXES,
182            &mut self.names,
183        );
184        self.entry_point_io.clear();
185        self.named_expressions.clear();
186        self.wrapped.clear();
187        self.written_committed_intersection = false;
188        self.written_candidate_intersection = false;
189        self.continue_ctx.clear();
190        self.need_bake_expressions.clear();
191        self.function_task_payload_var.clear();
192    }
193
194    /// Generates statements to be inserted immediately before and at the very
195    /// start of the body of each loop, to defeat infinite loop reasoning.
196    /// The 0th item of the returned tuple should be inserted immediately prior
197    /// to the loop and the 1st item should be inserted at the very start of
198    /// the loop body.
199    ///
200    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
201    fn gen_force_bounded_loop_statements(
202        &mut self,
203        level: back::Level,
204    ) -> Option<(String, String)> {
205        if !self.options.force_loop_bounding {
206            return None;
207        }
208
209        let loop_bound_name = self.namer.call("loop_bound");
210        let max = u32::MAX;
211        // Count down from u32::MAX rather than up from 0 to avoid hang on
212        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
213        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214        let level = level.next();
215        let break_and_inc = format!(
216            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218        );
219
220        Some((decl, break_and_inc))
221    }
222
223    /// Helper method used to find which expressions of a given function require baking
224    ///
225    /// # Notes
226    /// Clears `need_bake_expressions` set before adding to it
227    fn update_expressions_to_bake(
228        &mut self,
229        module: &Module,
230        func: &crate::Function,
231        info: &valid::FunctionInfo,
232    ) {
233        use crate::Expression;
234        self.need_bake_expressions.clear();
235        for (exp_handle, expr) in func.expressions.iter() {
236            let expr_info = &info[exp_handle];
237            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238            if min_ref_count <= expr_info.ref_count {
239                self.need_bake_expressions.insert(exp_handle);
240            }
241            if let Expression::Load { pointer } = *expr {
242                if info[pointer]
243                    .ty
244                    .inner_with(&module.types)
245                    .is_atomic_pointer(&module.types)
246                {
247                    self.need_bake_expressions.insert(exp_handle);
248                }
249            }
250
251            if let Expression::Math { fun, arg, arg1, .. } = *expr {
252                match fun {
253                    crate::MathFunction::Asinh
254                    | crate::MathFunction::Acosh
255                    | crate::MathFunction::Atanh
256                    | crate::MathFunction::Unpack2x16float
257                    | crate::MathFunction::Unpack2x16snorm
258                    | crate::MathFunction::Unpack2x16unorm
259                    | crate::MathFunction::Unpack4x8snorm
260                    | crate::MathFunction::Unpack4x8unorm
261                    | crate::MathFunction::Unpack4xI8
262                    | crate::MathFunction::Unpack4xU8
263                    | crate::MathFunction::Pack2x16float
264                    | crate::MathFunction::Pack2x16snorm
265                    | crate::MathFunction::Pack2x16unorm
266                    | crate::MathFunction::Pack4x8snorm
267                    | crate::MathFunction::Pack4x8unorm
268                    | crate::MathFunction::Pack4xI8
269                    | crate::MathFunction::Pack4xU8
270                    | crate::MathFunction::Pack4xI8Clamp
271                    | crate::MathFunction::Pack4xU8Clamp => {
272                        self.need_bake_expressions.insert(arg);
273                    }
274                    crate::MathFunction::CountLeadingZeros => {
275                        let inner = info[exp_handle].ty.inner_with(&module.types);
276                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277                            self.need_bake_expressions.insert(arg);
278                        }
279                    }
280                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281                        self.need_bake_expressions.insert(arg);
282                        self.need_bake_expressions.insert(arg1.unwrap());
283                    }
284                    _ => {}
285                }
286            }
287
288            if let Expression::Derivative { axis, ctrl, expr } = *expr {
289                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291                    self.need_bake_expressions.insert(expr);
292                }
293            }
294
295            if let Expression::GlobalVariable(_) = *expr {
296                let inner = info[exp_handle].ty.inner_with(&module.types);
297
298                if let TypeInner::Sampler { .. } = *inner {
299                    self.need_bake_expressions.insert(exp_handle);
300                }
301            }
302        }
303        for statement in func.body.iter() {
304            match *statement {
305                crate::Statement::SubgroupCollectiveOperation {
306                    op: _,
307                    collective_op: crate::CollectiveOperation::InclusiveScan,
308                    argument,
309                    result: _,
310                } => {
311                    self.need_bake_expressions.insert(argument);
312                }
313                crate::Statement::Atomic {
314                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315                    ..
316                } => {
317                    self.need_bake_expressions.insert(cmp);
318                }
319                _ => {}
320            }
321        }
322    }
323
324    pub fn write(
325        &mut self,
326        module: &Module,
327        module_info: &valid::ModuleInfo,
328        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
329    ) -> Result<super::ReflectionInfo, Error> {
330        self.reset(module);
331
332        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
333            return Err(Error::ShaderModelTooLow(
334                "mesh shaders".to_string(),
335                ShaderModel::V6_5,
336            ));
337        }
338
339        // Write special constants, if needed
340        if let Some(ref bt) = self.options.special_constants_binding {
341            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
342            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
343            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
344            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
345            writeln!(self.out, "}};")?;
346            write!(
347                self.out,
348                "ConstantBuffer<{}> {}: register(b{}",
349                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
350            )?;
351            if bt.space != 0 {
352                write!(self.out, ", space{}", bt.space)?;
353            }
354            writeln!(self.out, ");")?;
355
356            // Extra newline for readability
357            writeln!(self.out)?;
358        }
359
360        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
361            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
362            for i in 0..bt.size {
363                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
364            }
365            writeln!(self.out, "}};")?;
366            writeln!(
367                self.out,
368                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
369                group, group, bt.register, bt.space
370            )?;
371
372            // Extra newline for readability
373            writeln!(self.out)?;
374        }
375
376        // Save all entry point output types
377        let ep_results = module
378            .entry_points
379            .iter()
380            .map(|ep| (ep.stage, ep.function.result.clone()))
381            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
382
383        self.write_all_mat_cx2_typedefs_and_functions(module)?;
384
385        // Write all structs
386        for (handle, ty) in module.types.iter() {
387            if let TypeInner::Struct { ref members, span } = ty.inner {
388                if module.types[members.last().unwrap().ty]
389                    .inner
390                    .is_dynamically_sized(&module.types)
391                {
392                    // unsized arrays can only be in storage buffers,
393                    // for which we use `ByteAddressBuffer` anyway.
394                    continue;
395                }
396
397                let ep_result = ep_results.iter().find(|e| {
398                    if let Some(ref result) = e.1 {
399                        result.ty == handle
400                    } else {
401                        false
402                    }
403                });
404
405                self.write_struct(
406                    module,
407                    handle,
408                    members,
409                    span,
410                    ep_result.map(|r| (r.0, Io::Output)),
411                )?;
412                writeln!(self.out)?;
413            }
414        }
415
416        self.write_special_functions(module)?;
417
418        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
419        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
420
421        // Write all named constants
422        let mut constants = module
423            .constants
424            .iter()
425            .filter(|&(_, c)| c.name.is_some())
426            .peekable();
427        while let Some((handle, _)) = constants.next() {
428            self.write_global_constant(module, handle)?;
429            // Add extra newline for readability on last iteration
430            if constants.peek().is_none() {
431                writeln!(self.out)?;
432            }
433        }
434
435        // Write all globals
436        for (global, _) in module.global_variables.iter() {
437            self.write_global(module, global)?;
438        }
439
440        if !module.global_variables.is_empty() {
441            // Add extra newline for readability
442            writeln!(self.out)?;
443        }
444
445        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
446            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
447
448        // Write all entry points wrapped structs
449        for index in ep_range.clone() {
450            let ep = &module.entry_points[index];
451            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
452            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
453            self.entry_point_io.insert(index, ep_io);
454        }
455
456        // Write all regular functions
457        for (handle, function) in module.functions.iter() {
458            let info = &module_info[handle];
459
460            // Check if all of the globals are accessible
461            if !self.options.fake_missing_bindings {
462                if let Some((var_handle, _)) =
463                    module
464                        .global_variables
465                        .iter()
466                        .find(|&(var_handle, var)| match var.binding {
467                            Some(ref binding) if !info[var_handle].is_empty() => {
468                                self.options.resolve_resource_binding(binding).is_err()
469                                    && self
470                                        .options
471                                        .resolve_external_texture_resource_binding(binding)
472                                        .is_err()
473                            }
474                            _ => false,
475                        })
476                {
477                    log::debug!(
478                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
479                        handle,
480                        function.name,
481                        var_handle
482                    );
483                    continue;
484                }
485            }
486
487            let ctx = back::FunctionCtx {
488                ty: back::FunctionType::Function(handle),
489                info,
490                expressions: &function.expressions,
491                named_expressions: &function.named_expressions,
492            };
493            let name = self.names[&NameKey::Function(handle)].clone();
494
495            self.write_wrapped_functions(module, &ctx)?;
496
497            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;
498
499            writeln!(self.out)?;
500        }
501
502        let mut translated_ep_names = Vec::with_capacity(ep_range.len());
503
504        // Write all entry points
505        for index in ep_range {
506            let ep = &module.entry_points[index];
507            let info = module_info.get_entry_point(index);
508
509            if !self.options.fake_missing_bindings {
510                let mut ep_error = None;
511                for (var_handle, var) in module.global_variables.iter() {
512                    match var.binding {
513                        Some(ref binding) if !info[var_handle].is_empty() => {
514                            if let Err(err) = self.options.resolve_resource_binding(binding) {
515                                if self
516                                    .options
517                                    .resolve_external_texture_resource_binding(binding)
518                                    .is_err()
519                                {
520                                    ep_error = Some(err);
521                                    break;
522                                }
523                            }
524                        }
525                        _ => {}
526                    }
527                }
528                if let Some(err) = ep_error {
529                    translated_ep_names.push(Err(err));
530                    continue;
531                }
532            }
533
534            let ctx = back::FunctionCtx {
535                ty: back::FunctionType::EntryPoint(index as u16),
536                info,
537                expressions: &ep.function.expressions,
538                named_expressions: &ep.function.named_expressions,
539            };
540
541            self.write_wrapped_functions(module, &ctx)?;
542
543            // Mesh/task shaders have a wrapper entry point which is declared after the "main"
544            // user-written function. We therefore cannot always just document the next function.
545            let mut attribute_string = String::new();
546            if ep.stage.compute_like() {
547                // HLSL is calling workgroup size "num threads"
548                let num_threads = ep.workgroup_size;
549                writeln!(
550                    attribute_string,
551                    "[numthreads({}, {}, {})]",
552                    num_threads[0], num_threads[1], num_threads[2]
553                )?;
554            }
555            if let Some(ref info) = ep.mesh_info {
556                let topology_str = match info.topology {
557                    crate::MeshOutputTopology::Points => unreachable!(),
558                    crate::MeshOutputTopology::Lines => "line",
559                    crate::MeshOutputTopology::Triangles => "triangle",
560                };
561                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
562            }
563
564            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
565            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;
566
567            if index < module.entry_points.len() - 1 {
568                writeln!(self.out)?;
569            }
570
571            translated_ep_names.push(Ok(name));
572        }
573
574        Ok(super::ReflectionInfo {
575            entry_point_names: translated_ep_names,
576        })
577    }
578
579    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580        match *binding {
581            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582                write!(self.out, "precise ")?;
583            }
584            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585                write!(self.out, "noperspective ")?;
586            }
587            crate::Binding::Location {
588                interpolation,
589                sampling,
590                ..
591            } => {
592                if let Some(interpolation) = interpolation {
593                    if let Some(string) = interpolation.to_hlsl_str() {
594                        write!(self.out, "{string} ")?
595                    }
596                }
597
598                if let Some(sampling) = sampling {
599                    if let Some(string) = sampling.to_hlsl_str() {
600                        write!(self.out, "{string} ")?
601                    }
602                }
603            }
604            crate::Binding::BuiltIn(_) => {}
605        }
606
607        Ok(())
608    }
609
610    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
611    // if they are struct, so that the `stage` argument here could be omitted.
612    pub(super) fn write_semantic(
613        &mut self,
614        binding: &Option<crate::Binding>,
615        stage: Option<(ShaderStage, Io)>,
616    ) -> BackendResult {
617        let is_per_primitive = match *binding {
618            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619                if builtin == crate::BuiltIn::ViewIndex
620                    && self.options.shader_model < ShaderModel::V6_1
621                {
622                    return Err(Error::ShaderModelTooLow(
623                        "used @builtin(view_index) or SV_ViewID".to_string(),
624                        ShaderModel::V6_1,
625                    ));
626                }
627                if let Some(builtin_str) = builtin.to_hlsl_str()? {
628                    write!(self.out, " : {builtin_str}")?;
629                }
630                false
631            }
632            Some(crate::Binding::Location {
633                blend_src: Some(1),
634                per_primitive,
635                ..
636            }) => {
637                write!(self.out, " : SV_Target1")?;
638                per_primitive
639            }
640            Some(crate::Binding::Location {
641                location,
642                per_primitive,
643                ..
644            }) => {
645                if stage == Some((ShaderStage::Fragment, Io::Output)) {
646                    write!(self.out, " : SV_Target{location}")?;
647                } else {
648                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649                }
650                per_primitive
651            }
652            _ => false,
653        };
654        if is_per_primitive {
655            write!(self.out, " : primitive")?;
656        }
657
658        Ok(())
659    }
660
661    pub(super) fn write_interface_struct(
662        &mut self,
663        module: &Module,
664        shader_stage: (ShaderStage, Io),
665        struct_name: String,
666        var_name: Option<&str>,
667        mut members: Vec<EpStructMember>,
668    ) -> Result<EntryPointBinding, Error> {
669        let struct_name = self.namer.call(&struct_name);
670        // Sort the members so that first come the user-defined varyings
671        // in ascending locations, and then built-ins. This allows VS and FS
672        // interfaces to match with regards to order.
673        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675        write!(self.out, "struct {struct_name}")?;
676        writeln!(self.out, " {{")?;
677        let mut local_invocation_index_name = None;
678        let mut subgroup_id_used = false;
679        for m in members.iter() {
680            // Sanity check that each IO member is a built-in or is assigned a
681            // location. Also see note about nesting in `write_ep_input_struct`.
682            debug_assert!(m.binding.is_some());
683
684            match m.binding {
685                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686                    subgroup_id_used = true;
687                }
688                Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689                    local_invocation_index_name = Some(m.name.clone());
690                }
691                _ => (),
692            }
693
694            if is_subgroup_builtin_binding(&m.binding) {
695                continue;
696            }
697            write!(self.out, "{}", back::INDENT)?;
698            if let Some(ref binding) = m.binding {
699                self.write_modifier(binding)?;
700            }
701            self.write_type(module, m.ty)?;
702            write!(self.out, " {}", &m.name)?;
703            self.write_semantic(&m.binding, Some(shader_stage))?;
704            writeln!(self.out, ";")?;
705        }
706        if subgroup_id_used && local_invocation_index_name.is_none() {
707            let name = self.namer.call("local_invocation_index");
708            writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709            local_invocation_index_name = Some(name);
710        }
711        writeln!(self.out, "}};")?;
712        writeln!(self.out)?;
713
714        // See ordering notes on EntryPointInterface fields
715        match shader_stage.1 {
716            Io::Input => {
717                // bring back the original order
718                members.sort_by_key(|m| m.index);
719            }
720            Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721                // keep it sorted by binding
722            }
723        }
724
725        Ok(EntryPointBinding {
726            arg_name: self
727                .namer
728                .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729            ty_name: struct_name,
730            members,
731            local_invocation_index_name,
732        })
733    }
734
735    /// Flatten all entry point arguments into a single struct.
736    /// This is needed since we need to re-order them: first placing user locations,
737    /// then built-ins.
738    fn write_ep_input_struct(
739        &mut self,
740        module: &Module,
741        func: &crate::Function,
742        stage: ShaderStage,
743        entry_point_name: &str,
744    ) -> Result<EntryPointBinding, Error> {
745        let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747        let mut fake_members = Vec::new();
748        for arg in func.arguments.iter() {
749            // NOTE: We don't need to handle nesting structs. All members must
750            // be either built-ins or assigned a location. I.E. `binding` is
751            // `Some`. This is checked in `VaryingContext::validate`. See:
752            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
753            match module.types[arg.ty].inner {
754                TypeInner::Struct { ref members, .. } => {
755                    for member in members.iter() {
756                        let name = self.namer.call_or(&member.name, "member");
757                        let index = fake_members.len() as u32;
758                        fake_members.push(EpStructMember {
759                            name,
760                            ty: member.ty,
761                            binding: member.binding.clone(),
762                            index,
763                        });
764                    }
765                }
766                _ => {
767                    let member_name = self.namer.call_or(&arg.name, "member");
768                    let index = fake_members.len() as u32;
769                    fake_members.push(EpStructMember {
770                        name: member_name,
771                        ty: arg.ty,
772                        binding: arg.binding.clone(),
773                        index,
774                    });
775                }
776            }
777        }
778
779        self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780    }
781
782    /// Flatten all entry point results into a single struct.
783    /// This is needed since we need to re-order them: first placing user locations,
784    /// then built-ins.
785    fn write_ep_output_struct(
786        &mut self,
787        module: &Module,
788        result: &crate::FunctionResult,
789        stage: ShaderStage,
790        entry_point_name: &str,
791        frag_ep: Option<&FragmentEntryPoint<'_>>,
792    ) -> Result<EntryPointBinding, Error> {
793        let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795        let empty = [];
796        let members = match module.types[result.ty].inner {
797            TypeInner::Struct { ref members, .. } => members,
798            ref other => {
799                log::error!("Unexpected {other:?} output type without a binding");
800                &empty[..]
801            }
802        };
803
804        // Gather list of fragment input locations. We use this below to remove user-defined
805        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
806        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
807        // writer is supplied with information about the fragment entry point.
808        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809            let mut fs_input_locs = Vec::new();
810            for arg in frag_ep.func.arguments.iter() {
811                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813                    Some(crate::Binding::BuiltIn(_)) | None => {}
814                };
815
816                // NOTE: We don't need to handle struct nesting. See note in
817                // `write_ep_input_struct`.
818                match frag_ep.module.types[arg.ty].inner {
819                    TypeInner::Struct { ref members, .. } => {
820                        for member in members.iter() {
821                            push_if_location(&member.binding);
822                        }
823                    }
824                    _ => push_if_location(&arg.binding),
825                }
826            }
827            fs_input_locs.sort();
828            Some(fs_input_locs)
829        } else {
830            None
831        };
832
833        let mut fake_members = Vec::new();
834        for (index, member) in members.iter().enumerate() {
835            if let Some(ref fs_input_locs) = fs_input_locs {
836                match member.binding {
837                    Some(crate::Binding::Location { location, .. }) => {
838                        if fs_input_locs.binary_search(&location).is_err() {
839                            continue;
840                        }
841                    }
842                    Some(crate::Binding::BuiltIn(_)) | None => {}
843                }
844            }
845
846            let member_name = self.namer.call_or(&member.name, "member");
847            fake_members.push(EpStructMember {
848                name: member_name,
849                ty: member.ty,
850                binding: member.binding.clone(),
851                index: index as u32,
852            });
853        }
854
855        self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856    }
857
858    /// Writes special interface structures for an entry point. The special structures have
859    /// all the fields flattened into them and sorted by binding. They are needed to emulate
860    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
861    fn write_ep_interface(
862        &mut self,
863        module: &Module,
864        ep: &crate::EntryPoint,
865        ep_name: &str,
866        frag_ep: Option<&FragmentEntryPoint<'_>>,
867    ) -> Result<EntryPointInterface, Error> {
868        let func = &ep.function;
869        let stage = ep.stage;
870        Ok(EntryPointInterface {
871            input: if !func.arguments.is_empty()
872                && (stage == ShaderStage::Fragment
873                    || func
874                        .arguments
875                        .iter()
876                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877            {
878                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879            } else {
880                None
881            },
882            output: match func.result {
883                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885                }
886                _ => None,
887            },
888            mesh_vertices: if let Some(ref info) = ep.mesh_info {
889                Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890            } else {
891                None
892            },
893            mesh_primitives: if let Some(ref info) = ep.mesh_info {
894                Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895            } else {
896                None
897            },
898            mesh_indices: if let Some(ref info) = ep.mesh_info {
899                Some(self.write_ep_mesh_output_indices(info.topology)?)
900            } else {
901                None
902            },
903        })
904    }
905
906    fn write_ep_argument_initialization(
907        &mut self,
908        ep: &crate::EntryPoint,
909        ep_input: &EntryPointBinding,
910        fake_member: &EpStructMember,
911    ) -> BackendResult {
912        match fake_member.binding {
913            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914                write!(self.out, "WaveGetLaneCount()")?
915            }
916            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917                write!(self.out, "WaveGetLaneIndex()")?
918            }
919            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920                self.out,
921                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923            )?,
924            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925                write!(
926                    self.out,
927                    "{}.{} / WaveGetLaneCount()",
928                    ep_input.arg_name,
929                    // When writing SubgroupId, we always guarantee that local_invocation_index_name is written
930                    ep_input.local_invocation_index_name.as_ref().unwrap()
931                )?;
932            }
933            Some(crate::Binding::Location {
934                interpolation: Some(crate::Interpolation::PerVertex),
935                ..
936            }) => {
937                if self.options.shader_model < ShaderModel::V6_1 {
938                    return Err(Error::ShaderModelTooLow(
939                        "per_vertex fragment inputs".to_string(),
940                        ShaderModel::V6_1,
941                    ));
942                }
943                write!(
944                    self.out,
945                    "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946                    ep_input.arg_name,
947                    fake_member.name,
948                )?;
949            }
950            _ => {
951                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952            }
953        }
954        Ok(())
955    }
956
957    /// Write an entry point preface that initializes the arguments as specified in IR.
958    fn write_ep_arguments_initialization(
959        &mut self,
960        module: &Module,
961        func: &crate::Function,
962        ep_index: u16,
963    ) -> BackendResult {
964        let ep = &module.entry_points[ep_index as usize];
965        let ep_input = match self
966            .entry_point_io
967            .get_mut(&(ep_index as usize))
968            .unwrap()
969            .input
970            .take()
971        {
972            Some(ep_input) => ep_input,
973            None => return Ok(()),
974        };
975        let mut fake_iter = ep_input.members.iter();
976        for (arg_index, arg) in func.arguments.iter().enumerate() {
977            write!(self.out, "{}", back::INDENT)?;
978            self.write_type(module, arg.ty)?;
979            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980            write!(self.out, " {arg_name}")?;
981            match module.types[arg.ty].inner {
982                TypeInner::Array { base, size, .. } => {
983                    self.write_array_size(module, base, size)?;
984                    write!(self.out, " = ")?;
985                    self.write_ep_argument_initialization(
986                        ep,
987                        &ep_input,
988                        fake_iter.next().unwrap(),
989                    )?;
990                    writeln!(self.out, ";")?;
991                }
992                TypeInner::Struct { ref members, .. } => {
993                    write!(self.out, " = {{ ")?;
994                    for index in 0..members.len() {
995                        if index != 0 {
996                            write!(self.out, ", ")?;
997                        }
998                        self.write_ep_argument_initialization(
999                            ep,
1000                            &ep_input,
1001                            fake_iter.next().unwrap(),
1002                        )?;
1003                    }
1004                    writeln!(self.out, " }};")?;
1005                }
1006                _ => {
1007                    write!(self.out, " = ")?;
1008                    self.write_ep_argument_initialization(
1009                        ep,
1010                        &ep_input,
1011                        fake_iter.next().unwrap(),
1012                    )?;
1013                    writeln!(self.out, ";")?;
1014                }
1015            }
1016        }
1017        assert!(fake_iter.next().is_none());
1018        Ok(())
1019    }
1020
    /// Helper method used to write global variables
    ///
    /// External textures and samplers are delegated to
    /// `write_global_external_texture` and `write_global_sampler` respectively.
    /// Every other global is written as a single declaration whose shape depends
    /// on its address space.
    ///
    /// # Notes
    /// The declaration written here ends with a newline. Nothing is written when the
    /// global has a resource binding that `Options::resolve_resource_binding` rejects.
    ///
    /// # Errors
    /// Returns [`Error::Unimplemented`] for a global in the `Immediate` address
    /// space whose type is not a struct.
    ///
    /// # Panics
    /// Panics for globals in the `Function`, `RayPayload` or `IncomingRayPayload`
    /// address spaces, and for an `Immediate` global when
    /// `options.immediates_target` is `None`.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify the global by its element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are handled entirely differently, so defer entirely to that method.
        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
        // their bindings separately.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        // A global whose binding cannot be resolved is skipped rather than
        // reported as an error.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // Samplers are handled entirely differently, so defer entirely to that method.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the leading part of the declaration (storage class and/or type)
        // and pick the register letter used if the global has a binding.
        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // constant buffer declarations are expected to be inlined, e.g.
                // `cbuffer foo: register(b0) { field1: type1; }`
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                if global
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    write!(self.out, "globallycoherent ")?;
                }
                // Writable storage buffers use the `RW` type and a `u` register,
                // read-only ones a `t` register.
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // all storage textures are UAV, unconditionally
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::Immediate => {
                // The type of the immediates will be wrapped in `ConstantBuffer`
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unimplemented!()
            }
        };

        // If the global is immediate data, write the type now because it will be a
        // generic argument to `ConstantBuffer`
        if global.space == crate::AddressSpace::Immediate {
            self.write_global_type(module, global.ty)?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Immediates need to be assigned a binding explicitly by the consumer
        // since naga has no way to know the binding from the shader alone
        if global.space == crate::AddressSpace::Immediate {
            // Only struct-typed immediates are supported; anything else is
            // reported as unimplemented.
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .immediates_target
                .as_ref()
                .expect("No bind target was defined for the immediates block");
            write!(self.out, ": register(b{}", target.register)?;
            // Register space 0 is left implicit.
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // this was already resolved earlier when we started evaluating an entry point.
            // The same lookup was also checked near the top of this function, which
            // returns early on failure.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // need to write the binding array size if the type was emitted with `write_type`
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                // A size supplied by the bind target takes precedence over the
                // size declared in the module.
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            // Register space 0 is left implicit.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Private globals always get an initializer: the declared one if
            // present, otherwise a default value.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // Uniforms are written as a `cbuffer` whose body declares a single
            // variable with the same name as the buffer.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1223
1224    fn write_global_sampler(
1225        &mut self,
1226        module: &Module,
1227        handle: Handle<crate::GlobalVariable>,
1228        global: &crate::GlobalVariable,
1229    ) -> BackendResult {
1230        let binding = *global.binding.as_ref().unwrap();
1231
1232        let key = super::SamplerIndexBufferKey {
1233            group: binding.group,
1234        };
1235        self.write_wrapped_sampler_buffer(key)?;
1236
1237        // This was already validated, so we can confidently unwrap it.
1238        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240        match module.types[global.ty].inner {
1241            TypeInner::Sampler { comparison } => {
1242                // If we are generating a static access, we create a variable for the sampler.
1243                //
1244                // This prevents the DXIL from containing multiple lookups for the sampler, which
1245                // the backend compiler will then have to eliminate. AMD does seem to be able to
1246                // eliminate these, but better safe than sorry.
1247
1248                write!(self.out, "static const ")?;
1249                self.write_type(module, global.ty)?;
1250
1251                let heap_var = if comparison {
1252                    COMPARISON_SAMPLER_HEAP_VAR
1253                } else {
1254                    SAMPLER_HEAP_VAR
1255                };
1256
1257                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258                let name = &self.names[&NameKey::GlobalVariable(handle)];
1259                writeln!(
1260                    self.out,
1261                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262                    register = bt.register
1263                )?;
1264            }
1265            TypeInner::BindingArray { .. } => {
1266                // If we are generating a binding array, we cannot directly access the sampler as the index
1267                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1268                // that represents the "base" index into the sampler index buffer. This constant is added
1269                // to the user provided index to get the final index into the sampler index buffer.
1270
1271                let name = &self.names[&NameKey::GlobalVariable(handle)];
1272                writeln!(
1273                    self.out,
1274                    "static const uint {name} = {register};",
1275                    register = bt.register
1276                )?;
1277            }
1278            _ => unreachable!(),
1279        };
1280
1281        Ok(())
1282    }
1283
1284    /// Write the declarations for an external texture global variable.
1285    /// These are emitted as multiple global variables: Three `Texture2D`s
1286    /// (one for each plane) and a parameters cbuffer.
1287    fn write_global_external_texture(
1288        &mut self,
1289        module: &Module,
1290        handle: Handle<crate::GlobalVariable>,
1291        global: &crate::GlobalVariable,
1292    ) -> BackendResult {
1293        let res_binding = global
1294            .binding
1295            .as_ref()
1296            .expect("External texture global variables must have a resource binding");
1297        let ext_tex_bindings = match self
1298            .options
1299            .resolve_external_texture_resource_binding(res_binding)
1300        {
1301            Ok(bindings) => bindings,
1302            Err(err) => {
1303                log::debug!(
1304                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305                    handle,
1306                    global.name,
1307                    err,
1308                );
1309                return Ok(());
1310            }
1311        };
1312
1313        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314            write!(
1315                self.out,
1316                "Texture2D<float4> {}: register(t{}",
1317                name, bt.register
1318            )?;
1319            if bt.space != 0 {
1320                write!(self.out, ", space{}", bt.space)?;
1321            }
1322            writeln!(self.out, ");")?;
1323            Ok(())
1324        };
1325        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326            let plane_name = &self.names
1327                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328            write_plane(bt, plane_name)?;
1329        }
1330
1331        let params_name = &self.names
1332            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333        let params_ty_name =
1334            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335        write!(
1336            self.out,
1337            "cbuffer {}: register(b{}",
1338            params_name, ext_tex_bindings.params.register
1339        )?;
1340        if ext_tex_bindings.params.space != 0 {
1341            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342        }
1343        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345        Ok(())
1346    }
1347
1348    /// Helper method used to write global constants
1349    ///
1350    /// # Notes
1351    /// Ends in a newline
1352    fn write_global_constant(
1353        &mut self,
1354        module: &Module,
1355        handle: Handle<crate::Constant>,
1356    ) -> BackendResult {
1357        write!(self.out, "static const ")?;
1358        let constant = &module.constants[handle];
1359        self.write_type(module, constant.ty)?;
1360        let name = &self.names[&NameKey::Constant(handle)];
1361        write!(self.out, " {name}")?;
1362        // Write size for array type
1363        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364            self.write_array_size(module, base, size)?;
1365        }
1366        write!(self.out, " = ")?;
1367        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368        writeln!(self.out, ";")?;
1369        Ok(())
1370    }
1371
1372    pub(super) fn write_array_size(
1373        &mut self,
1374        module: &Module,
1375        base: Handle<crate::Type>,
1376        size: crate::ArraySize,
1377    ) -> BackendResult {
1378        write!(self.out, "[")?;
1379
1380        match size.resolve(module.to_ctx())? {
1381            proc::IndexableLength::Known(size) => {
1382                write!(self.out, "{size}")?;
1383            }
1384            proc::IndexableLength::Dynamic => unreachable!(),
1385        }
1386
1387        write!(self.out, "]")?;
1388
1389        if let TypeInner::Array {
1390            base: next_base,
1391            size: next_size,
1392            ..
1393        } = module.types[base].inner
1394        {
1395            self.write_array_size(module, next_base, next_size)?;
1396        }
1397
1398        Ok(())
1399    }
1400
1401    /// Helper method used to write structs
1402    ///
1403    /// # Notes
1404    /// Ends in a newline
1405    fn write_struct(
1406        &mut self,
1407        module: &Module,
1408        handle: Handle<crate::Type>,
1409        members: &[crate::StructMember],
1410        span: u32,
1411        shader_stage: Option<(ShaderStage, Io)>,
1412    ) -> BackendResult {
1413        // Write struct name
1414        let struct_name = &self.names[&NameKey::Type(handle)];
1415        writeln!(self.out, "struct {struct_name} {{")?;
1416
1417        let mut last_offset = 0;
1418        for (index, member) in members.iter().enumerate() {
1419            if member.binding.is_none() && member.offset > last_offset {
1420                // using int as padding should work as long as the backend
1421                // doesn't support a type that's less than 4 bytes in size
1422                // (Error::UnsupportedScalar catches this)
1423                let padding = (member.offset - last_offset) / 4;
1424                for i in 0..padding {
1425                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1426                }
1427            }
1428            let ty_inner = &module.types[member.ty].inner;
1429            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1430
1431            // The indentation is only for readability
1432            write!(self.out, "{}", back::INDENT)?;
1433
1434            match module.types[member.ty].inner {
1435                TypeInner::Array { base, size, .. } => {
1436                    // HLSL arrays are written as `type name[size]`
1437
1438                    self.write_global_type(module, member.ty)?;
1439
1440                    // Write `name`
1441                    write!(
1442                        self.out,
1443                        " {}",
1444                        &self.names[&NameKey::StructMember(handle, index as u32)]
1445                    )?;
1446                    // Write [size]
1447                    self.write_array_size(module, base, size)?;
1448                }
1449                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1450                // See the module-level block comment in mod.rs for details.
1451                TypeInner::Matrix {
1452                    rows,
1453                    columns,
1454                    scalar,
1455                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1456                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1457                    let field_name_key = NameKey::StructMember(handle, index as u32);
1458
1459                    for i in 0..columns as u8 {
1460                        if i != 0 {
1461                            write!(self.out, "; ")?;
1462                        }
1463                        self.write_value_type(module, &vec_ty)?;
1464                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1465                    }
1466                }
1467                _ => {
1468                    // Write modifier before type
1469                    if let Some(ref binding) = member.binding {
1470                        self.write_modifier(binding)?;
1471                    }
1472
1473                    // Even though Naga IR matrices are column-major, we must describe
1474                    // matrices passed from the CPU as being in row-major order.
1475                    // See the module-level block comment in mod.rs for details.
1476                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1477                        write!(self.out, "row_major ")?;
1478                    }
1479
1480                    // Write the member type and name
1481                    self.write_type(module, member.ty)?;
1482                    write!(
1483                        self.out,
1484                        " {}",
1485                        &self.names[&NameKey::StructMember(handle, index as u32)]
1486                    )?;
1487                }
1488            }
1489
1490            self.write_semantic(&member.binding, shader_stage)?;
1491            writeln!(self.out, ";")?;
1492        }
1493
1494        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1495        if members.last().unwrap().binding.is_none() && span > last_offset {
1496            let padding = (span - last_offset) / 4;
1497            for i in 0..padding {
1498                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1499            }
1500        }
1501
1502        writeln!(self.out, "}};")?;
1503        Ok(())
1504    }
1505
1506    /// Helper method used to write global/structs non image/sampler types
1507    ///
1508    /// # Notes
1509    /// Adds no trailing or leading whitespace
1510    pub(super) fn write_global_type(
1511        &mut self,
1512        module: &Module,
1513        ty: Handle<crate::Type>,
1514    ) -> BackendResult {
1515        let matrix_data = get_inner_matrix_data(module, ty);
1516
1517        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1518        // See the module-level block comment in mod.rs for details.
1519        if let Some(MatrixType {
1520            columns,
1521            rows: crate::VectorSize::Bi,
1522            width: 4,
1523        }) = matrix_data
1524        {
1525            write!(self.out, "__mat{}x2", columns as u8)?;
1526        } else {
1527            // Even though Naga IR matrices are column-major, we must describe
1528            // matrices passed from the CPU as being in row-major order.
1529            // See the module-level block comment in mod.rs for details.
1530            if matrix_data.is_some() {
1531                write!(self.out, "row_major ")?;
1532            }
1533
1534            self.write_type(module, ty)?;
1535        }
1536
1537        Ok(())
1538    }
1539
1540    /// Helper method used to write non image/sampler types
1541    ///
1542    /// # Notes
1543    /// Adds no trailing or leading whitespace
1544    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545        let inner = &module.types[ty].inner;
1546        match *inner {
1547            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548            // hlsl array has the size separated from the base type
1549            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550                self.write_type(module, base)?
1551            }
1552            ref other => self.write_value_type(module, other)?,
1553        }
1554
1555        Ok(())
1556    }
1557
1558    /// Helper method used to write value types
1559    ///
1560    /// # Notes
1561    /// Adds no trailing or leading whitespace
1562    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563        match *inner {
1564            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566            }
1567            TypeInner::Vector { size, scalar } => {
1568                write!(
1569                    self.out,
1570                    "{}{}",
1571                    scalar.to_hlsl_str()?,
1572                    common::vector_size_str(size)
1573                )?;
1574            }
1575            TypeInner::Matrix {
1576                columns,
1577                rows,
1578                scalar,
1579            } => {
1580                // The IR supports only float matrix
1581                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1582
1583                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1584                write!(
1585                    self.out,
1586                    "{}{}x{}",
1587                    scalar.to_hlsl_str()?,
1588                    common::vector_size_str(columns),
1589                    common::vector_size_str(rows),
1590                )?;
1591            }
1592            TypeInner::Image {
1593                dim,
1594                arrayed,
1595                class,
1596            } => {
1597                self.write_image_type(dim, arrayed, class)?;
1598            }
1599            TypeInner::Sampler { comparison } => {
1600                let sampler = if comparison {
1601                    "SamplerComparisonState"
1602                } else {
1603                    "SamplerState"
1604                };
1605                write!(self.out, "{sampler}")?;
1606            }
1607            // HLSL arrays are written as `type name[size]`
1608            // Current code is written arrays only as `[size]`
1609            // Base `type` and `name` should be written outside
1610            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611                self.write_array_size(module, base, size)?;
1612            }
1613            TypeInner::AccelerationStructure { .. } => {
1614                write!(self.out, "RaytracingAccelerationStructure")?;
1615            }
1616            TypeInner::RayQuery { .. } => {
1617                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1618                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619            }
1620            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621        }
1622
1623        Ok(())
1624    }
1625
1626    /// Helper method used to write functions
    ///
    /// Writes the full HLSL definition of `func`: the optional array-return
    /// `typedef`, the return type, the function name, the argument list, the
    /// output semantic (entry points only), the local variable declarations
    /// and finally the statement body.
    ///
    /// `header` is written verbatim in front of the definition, except for
    /// task and mesh entry points. Those are written as an inner function
    /// under a freshly generated `_{name}` identifier, and `header` is handed
    /// to `write_nested_function_outer` once the inner function is complete.
    ///
    /// `func_ctx.ty` selects between the regular-function and entry-point code
    /// paths; `info` is used for expression baking and for global-variable
    /// usage lookups.
    ///
    /// Clears `self.named_expressions` before returning successfully.
    ///
1627    /// # Notes
1628    /// Ends in a newline
1629    fn write_function(
1630        &mut self,
1631        module: &Module,
1632        name: &str,
1633        func: &crate::Function,
1634        func_ctx: &back::FunctionCtx<'_>,
1635        info: &valid::FunctionInfo,
1636        header: String,
1637    ) -> BackendResult {
1638        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1639
1640        self.update_expressions_to_bake(module, func, info);
        // `Some` only when this function is an entry point.
1641        let ep = match func_ctx.ty {
1642            back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1643            back::FunctionType::Function(_) => None,
1644        };
1645
        // Task and mesh entry points are written as an inner ("nested")
        // function; their header is deferred to `write_nested_function_outer`.
1646        let nested = matches!(
1647            ep,
1648            Some(crate::EntryPoint {
1649                stage: ShaderStage::Task | ShaderStage::Mesh,
1650                ..
1651            })
1652        );
1653        if !nested {
1654            write!(self.out, "{header}")?;
1655        }
1656
1657        if let Some(ref result) = func.result {
1658            // Write typedef if return type is an array
1659            let array_return_type = match module.types[result.ty].inner {
1660                TypeInner::Array { base, size, .. } => {
1661                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1662                    write!(self.out, "typedef ")?;
1663                    self.write_type(module, result.ty)?;
1664                    write!(self.out, " {array_return_type}")?;
1665                    self.write_array_size(module, base, size)?;
1666                    writeln!(self.out, ";")?;
1667                    Some(array_return_type)
1668                }
1669                _ => None,
1670            };
1671
1672            // Write modifier (only for an invariant position built-in result)
1673            if let Some(
1674                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1675            ) = result.binding
1676            {
1677                self.write_modifier(binding)?;
1678            }
1679
1680            // Write return type
1681            match func_ctx.ty {
1682                back::FunctionType::Function(_) => {
                    // Regular functions use the typedef written above when
                    // they return an array.
1683                    if let Some(array_return_type) = array_return_type {
1684                        write!(self.out, "{array_return_type}")?;
1685                    } else {
1686                        self.write_type(module, result.ty)?;
1687                    }
1688                }
1689                back::FunctionType::EntryPoint(index) => {
                    // Entry points with a recorded output interface return
                    // that interface's type name instead of the IR type.
1690                    if let Some(ref ep_output) =
1691                        self.entry_point_io.get(&(index as usize)).unwrap().output
1692                    {
1693                        write!(self.out, "{}", ep_output.ty_name)?;
1694                    } else {
1695                        self.write_type(module, result.ty)?;
1696                    }
1697                }
1698            }
1699        } else {
1700            write!(self.out, "void")?;
1701        }
1702
        // Nested entry points get a fresh `_{name}` identifier from the namer;
        // everything else is written under `name` itself.
1703        let nested_name = if nested {
1704            self.namer.call(&format!("_{name}"))
1705        } else {
1706            name.to_string()
1707        };
1708
1709        // Write function name
1710        write!(self.out, " {nested_name}(")?;
1711
1712        let need_workgroup_variables_initialization =
1713            self.need_workgroup_variables_initialization(func_ctx, module);
1714
        // Yields "" on its first call and ", " on every later call, so it can
        // be written in front of each argument.
1715        let mut any_args_written = false;
1716        let mut separator = || {
1717            if any_args_written {
1718                ", "
1719            } else {
1720                any_args_written = true;
1721                ""
1722            }
1723        };
1724
1725        let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1726        let mut local_invocation_index_name = None;
1727        // For nested entry points, collect arg names as we write them so that
1728        // write_nested_function_outer can pass the exact same names to the call site.
1729        let mut nested_wgsl_args: Vec<String> = Vec::new();
1730        let mut nested_task_payload_name: Option<String> = None;
1731        // Write function arguments (regular functions and entry points are handled separately)
1732        match func_ctx.ty {
1733            back::FunctionType::Function(handle) => {
1734                for (index, arg) in func.arguments.iter().enumerate() {
1735                    write!(self.out, "{}", separator())?;
1736                    self.write_function_argument(module, handle, arg, index)?;
1737                }
1738                // If this reads a task payload variable the variable needs to be passed as an `in` argument
1739                for (var_handle, var) in module.global_variables.iter() {
1740                    let uses = info[var_handle];
1741                    if uses.contains(valid::GlobalUse::READ)
1742                        && !uses.contains(valid::GlobalUse::WRITE)
1743                        && var.space == crate::AddressSpace::TaskPayload
1744                    {
1745                        self.function_task_payload_var.insert(handle, var_handle);
1746                        write!(self.out, "{}in ", separator())?;
1747
1748                        self.write_type(module, var.ty)?;
1749                        let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1750                        write!(self.out, " {name}")?;
                        // Only the first matching variable is passed.
1751                        break;
1752                    }
1753                }
1754            }
1755            back::FunctionType::EntryPoint(ep_index) => {
1756                let ep = &module.entry_points[ep_index as usize];
1757                if let Some(ref ep_input) =
1758                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1759                {
                    // A recorded input interface replaces the individual
                    // arguments with a single argument of that type.
1760                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
                    // The result is discarded: this call only records that an
                    // argument has been written.
1761                    separator();
1762                    nested_wgsl_args.push(ep_input.arg_name.clone());
1763                } else {
1764                    let stage = ep.stage;
1765                    for (index, arg) in func.arguments.iter().enumerate() {
1766                        write!(self.out, "{}", separator())?;
1767                        self.write_type(module, arg.ty)?;
1768
1769                        let argument_name =
1770                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1771
                        // Remember a user-declared local invocation index so
                        // that no extra one is appended below.
1772                        if arg.binding
1773                            == Some(crate::Binding::BuiltIn(
1774                                crate::BuiltIn::LocalInvocationIndex,
1775                            ))
1776                        {
1777                            local_invocation_index_name = Some(argument_name.clone());
1778                        }
1779
1780                        nested_wgsl_args.push(argument_name.clone());
1781                        write!(self.out, " {argument_name}")?;
1782                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1783                            self.write_array_size(module, base, size)?;
1784                        }
1785
1786                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1787                    }
1788                }
                // Mesh entry points receive their task payload as an extra
                // `in` argument.
1789                if ep.stage == ShaderStage::Mesh {
1790                    if let Some(var_handle) = ep.task_payload {
1791                        let var = &module.global_variables[var_handle];
1792                        write!(self.out, "{}in ", separator())?;
1793                        self.write_type(module, var.ty)?;
1794                        let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1795                        write!(self.out, " {arg_name}")?;
1796                        nested_task_payload_name = Some(arg_name.clone());
1797                        if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1798                            self.write_array_size(module, base, size)?;
1799                        }
1800                    }
1801                }
                // Append an `SV_GroupIndex` argument when the index is needed
                // but no argument above provided it.
1802                if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1803                    let name = self.namer.call("local_invocation_index");
1804                    write!(self.out, "{}uint {name}", separator())?;
1805                    write!(self.out, " : SV_GroupIndex")?;
1806                    local_invocation_index_name = Some(name);
1807                }
1808            }
1809        }
1810        // End of arguments
1811        write!(self.out, ")")?;
1812
1813        // Write the output semantic if it is present (entry points only)
1814        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1815            let stage = module.entry_points[index as usize].stage;
1816            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1817                self.write_semantic(binding, Some((stage, Io::Output)))?;
1818            }
1819        }
1820
1821        // Function body start
1822        writeln!(self.out)?;
1823        writeln!(self.out, "{{")?;
1824
        // Non-nested entry points initialize the workgroup variables here,
        // guarded by `local_invocation_index == 0` and followed by a workgroup
        // barrier. For nested entry points the flag is forwarded to
        // `write_nested_function_outer` instead.
1825        if need_workgroup_variables_initialization && !nested {
1826            let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1827                unreachable!();
1828            };
1829            writeln!(
1830                self.out,
1831                "{}if ({} == 0) {{",
1832                back::INDENT,
1833                // need_workgroup_variables_initialization forces this to be written
1834                // if the user doesn't specify it (so this must be Some())
1835                local_invocation_index_name.as_ref().unwrap(),
1836            )?;
1837            self.write_workgroup_variables_initialization(
1838                func_ctx,
1839                module,
1840                module.entry_points[index as usize].stage,
1841            )?;
1842
1843            writeln!(self.out, "{}}}", back::INDENT)?;
1844            self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1845        }
1846
1847        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1848            self.write_ep_arguments_initialization(module, func, index)?;
1849        }
1850
1851        // Write function local variables
1852        for (handle, local) in func.local_variables.iter() {
1853            // Write indentation (only for readability)
1854            write!(self.out, "{}", back::INDENT)?;
1855
1856            // Write the local type, then the local name.
1857            // The leading space before the name is important
1858            self.write_type(module, local.ty)?;
1859            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1860            // Write size for array type
1861            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1862                self.write_array_size(module, base, size)?;
1863            }
1864
            // Every local except a ray query gets an initializer.
1865            let is_ray_query = match module.types[local.ty].inner {
1866                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1867                TypeInner::RayQuery { .. } => true,
1868                _ => {
1869                    write!(self.out, " = ")?;
1870                    // Write the local initializer if needed
1871                    if let Some(init) = local.init {
1872                        self.write_expr(module, init, func_ctx)?;
1873                    } else {
1874                        // Zero initialize local variables
1875                        self.write_default_init(module, local.ty)?;
1876                    }
1877                    false
1878                }
1879            };
1880            // Finish the local with `;` and add a newline (only for readability)
1881            writeln!(self.out, ";")?;
1882            // If it's a ray query, we also want a tracker variable
1883            if is_ray_query {
1884                write!(self.out, "{}", back::INDENT)?;
1885                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1886                writeln!(
1887                    self.out,
1888                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1889                    self.names[&func_ctx.name_key(handle)]
1890                )?;
1891            }
1892        }
1893
        // Blank line between the local declarations and the body.
1894        if !func.local_variables.is_empty() {
1895            writeln!(self.out)?;
1896        }
1897
1898        // Write the function body (statement list)
1899        for sta in func.body.iter() {
1900            // The indentation should always be 1 when writing the function body
1901            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1902        }
1903
1904        writeln!(self.out, "}}")?;
1905
        // Task and mesh entry points: hand the deferred header and the
        // collected argument names to `write_nested_function_outer`.
1906        if nested {
1907            self.write_nested_function_outer(
1908                module,
1909                func_ctx,
1910                &header,
1911                name,
1912                need_workgroup_variables_initialization,
1913                &nested_name,
1914                ep.unwrap(),
1915                NestedEntryPointArgs {
1916                    user_args: nested_wgsl_args,
1917                    task_payload: nested_task_payload_name,
1918                    // guaranteed to be set for nested functions (task/mesh shaders)
1919                    local_invocation_index: local_invocation_index_name.unwrap(),
1920                },
1921            )?;
1922        }
1923
1924        self.named_expressions.clear();
1925
1926        Ok(())
1927    }
1928
1929    fn write_function_argument(
1930        &mut self,
1931        module: &Module,
1932        handle: Handle<crate::Function>,
1933        arg: &crate::FunctionArgument,
1934        index: usize,
1935    ) -> BackendResult {
1936        // External texture arguments must be expanded into separate
1937        // arguments for each plane and the params buffer.
1938        if let TypeInner::Image {
1939            class: crate::ImageClass::External,
1940            ..
1941        } = module.types[arg.ty].inner
1942        {
1943            return self.write_function_external_texture_argument(module, handle, index);
1944        }
1945
1946        // Write argument type
1947        let arg_ty = match module.types[arg.ty].inner {
1948            // pointers in function arguments are expected and resolve to `inout`
1949            TypeInner::Pointer { base, .. } => {
1950                //TODO: can we narrow this down to just `in` when possible?
1951                write!(self.out, "inout ")?;
1952                base
1953            }
1954            _ => arg.ty,
1955        };
1956        self.write_type(module, arg_ty)?;
1957
1958        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960        // Write argument name. Space is important.
1961        write!(self.out, " {argument_name}")?;
1962        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963            self.write_array_size(module, base, size)?;
1964        }
1965
1966        Ok(())
1967    }
1968
1969    fn write_function_external_texture_argument(
1970        &mut self,
1971        module: &Module,
1972        handle: Handle<crate::Function>,
1973        index: usize,
1974    ) -> BackendResult {
1975        let plane_names = [0, 1, 2].map(|i| {
1976            &self.names[&NameKey::ExternalTextureFunctionArgument(
1977                handle,
1978                index as u32,
1979                ExternalTextureNameKey::Plane(i),
1980            )]
1981        });
1982        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983            handle,
1984            index as u32,
1985            ExternalTextureNameKey::Params,
1986        )];
1987        let params_ty_name =
1988            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989        write!(
1990            self.out,
1991            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992            plane_names[0], plane_names[1], plane_names[2],
1993        )?;
1994        Ok(())
1995    }
1996
1997    fn need_workgroup_variables_initialization(
1998        &mut self,
1999        func_ctx: &back::FunctionCtx,
2000        module: &Module,
2001    ) -> bool {
2002        self.options.zero_initialize_workgroup_memory
2003            && func_ctx.ty.is_compute_like_entry_point(module)
2004            && module.global_variables.iter().any(|(handle, var)| {
2005                !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006            })
2007    }
2008
2009    pub(super) fn write_workgroup_variables_initialization(
2010        &mut self,
2011        func_ctx: &back::FunctionCtx,
2012        module: &Module,
2013        stage: ShaderStage,
2014    ) -> BackendResult {
2015        let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016            // Read-only in mesh shaders
2017            let task_needs_zero =
2018                (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019            !func_ctx.info[handle].is_empty()
2020                && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021        });
2022
2023        for (handle, var) in vars {
2024            let name = &self.names[&NameKey::GlobalVariable(handle)];
2025            write!(self.out, "{}{} = ", back::Level(2), name)?;
2026            self.write_default_init(module, var.ty)?;
2027            writeln!(self.out, ";")?;
2028        }
2029        Ok(())
2030    }
2031
2032    /// Helper method used to write switches
    ///
    /// Writes the HLSL for a switch on `selector` over `cases`, starting at
    /// indentation `level`.
    ///
    /// Two shapes are produced:
    /// - when every case except the last is an empty fall-through, only the
    ///   last case's body is written, wrapped in `do { ... } while(false);`;
    /// - otherwise a regular `switch` is written, with the bodies of
    ///   fall-through targets duplicated into any non-empty case that falls
    ///   through.
    ///
    /// If `continue_ctx.enter_switch` returns a variable, a `bool` with that
    /// name is declared before the switch. After the switch, the value
    /// returned by `continue_ctx.exit_switch` decides whether an
    /// `if (variable) { continue; }` or `if (variable) { break; }` follows.
2033    fn write_switch(
2034        &mut self,
2035        module: &Module,
2036        func_ctx: &back::FunctionCtx<'_>,
2037        level: back::Level,
2038        selector: Handle<crate::Expression>,
2039        cases: &[crate::SwitchCase],
2040    ) -> BackendResult {
2041        // Indentation for case labels (level 1) and for case bodies (level 2)
2042        let indent_level_1 = level.next();
2043        let indent_level_2 = indent_level_1.next();
2044
2045        // See docs of `back::continue_forward` module.
2046        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
2047            writeln!(self.out, "{level}bool {variable} = false;",)?;
2048        };
2049
2050        // Check if there is only one body, by seeing if all except the last case are fall through
2051        // with empty bodies. FXC doesn't handle these switches correctly, so
2052        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
2053        // is no need to check if one of the cases would have matched.
2054        let one_body = cases
2055            .iter()
2056            .rev()
2057            .skip(1)
2058            .all(|case| case.fall_through && case.body.is_empty());
2059        if one_body {
2060            // Start the do-while
2061            writeln!(self.out, "{level}do {{")?;
2062            // Note: Expressions have no side-effects so we don't need to emit selector expression.
2063
2064            // Body
2065            if let Some(case) = cases.last() {
2066                for sta in case.body.iter() {
2067                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
2068                }
2069            }
2070            // End do-while
2071            writeln!(self.out, "{level}}} while(false);")?;
2072        } else {
2073            // Start the switch
2074            write!(self.out, "{level}")?;
2075            write!(self.out, "switch(")?;
2076            self.write_expr(module, selector, func_ctx)?;
2077            writeln!(self.out, ") {{")?;
2078
2079            for (i, case) in cases.iter().enumerate() {
                // Case label; unsigned values carry a `u` suffix.
2080                match case.value {
2081                    crate::SwitchValue::I32(value) => {
2082                        write!(self.out, "{indent_level_1}case {value}:")?
2083                    }
2084                    crate::SwitchValue::U32(value) => {
2085                        write!(self.out, "{indent_level_1}case {value}u:")?
2086                    }
2087                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
2088                }
2089
2090                // The new block is not only stylistic, it plays a role here:
2091                // We might end up having to write the same case body
2092                // multiple times due to FXC not supporting fallthrough.
2093                // Therefore, some `Expression`s written by `Statement::Emit`
2094                // will end up having the same name (`_expr<handle_index>`).
2095                // So we need to put each case in its own scope.
                // Empty fall-through cases are written as a bare label
                // without braces.
2096                let write_block_braces = !(case.fall_through && case.body.is_empty());
2097                if write_block_braces {
2098                    writeln!(self.out, " {{")?;
2099                } else {
2100                    writeln!(self.out)?;
2101                }
2102
2103                // Although FXC does support a series of case clauses before
2104                // a block[^yes], it does not support fallthrough from a
2105                // non-empty case block to the next[^no]. If this case has a
2106                // non-empty body with a fallthrough, emulate that by
2107                // duplicating the bodies of all the cases it would fall
2108                // into as extensions of this case's own body. This makes
2109                // the HLSL output potentially quadratic in the size of the
2110                // Naga IR.
2111                //
2112                // [^yes]: ```hlsl
2113                // case 1:
2114                // case 2: do_stuff()
2115                // ```
2116                // [^no]: ```hlsl
2117                // case 1: do_this();
2118                // case 2: do_that();
2119                // ```
2120                if case.fall_through && !case.body.is_empty() {
                    // `end_case_idx` is the index of the first later case that
                    // does not fall through. The `unwrap` relies on such a
                    // case existing after `i`.
                    // NOTE(review): presumably guaranteed by validation; confirm.
2121                    let curr_len = i + 1;
2122                    let end_case_idx = curr_len
2123                        + cases
2124                            .iter()
2125                            .skip(curr_len)
2126                            .position(|case| !case.fall_through)
2127                            .unwrap();
2128                    let indent_level_3 = indent_level_2.next();
                    // The range starts at `i`, so this case's own body is
                    // written first, followed by each body it falls into.
2129                    for case in &cases[i..=end_case_idx] {
2130                        writeln!(self.out, "{indent_level_2}{{")?;
2131                        let prev_len = self.named_expressions.len();
2132                        for sta in case.body.iter() {
2133                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
2134                        }
2135                        // Clear all named expressions that were previously inserted by the statements in the block
2136                        self.named_expressions.truncate(prev_len);
2137                        writeln!(self.out, "{indent_level_2}}}")?;
2138                    }
2139
                    // Add a `break` unless the final duplicated body already
                    // ends in a terminator statement.
2140                    let last_case = &cases[end_case_idx];
2141                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
2142                        writeln!(self.out, "{indent_level_2}break;")?;
2143                    }
2144                } else {
2145                    for sta in case.body.iter() {
2146                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
2147                    }
                    // Non-fall-through cases get a `break` unless their body
                    // already ends in a terminator statement.
2148                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
2149                        writeln!(self.out, "{indent_level_2}break;")?;
2150                    }
2151                }
2152
2153                if write_block_braces {
2154                    writeln!(self.out, "{indent_level_1}}}")?;
2155                }
2156            }
2157
2158            writeln!(self.out, "{level}}}")?;
2159        }
2160
2161        // Handle any forwarded continue statements.
2162        use back::continue_forward::ExitControlFlow;
2163        let op = match self.continue_ctx.exit_switch() {
2164            ExitControlFlow::None => None,
2165            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
2166            ExitControlFlow::Break { variable } => Some(("break", variable)),
2167        };
2168        if let Some((control_flow, variable)) = op {
2169            writeln!(self.out, "{level}if ({variable}) {{")?;
2170            writeln!(self.out, "{indent_level_1}{control_flow};")?;
2171            writeln!(self.out, "{level}}}")?;
2172        }
2173
2174        Ok(())
2175    }
2176
2177    fn write_index(
2178        &mut self,
2179        module: &Module,
2180        index: Index,
2181        func_ctx: &back::FunctionCtx<'_>,
2182    ) -> BackendResult {
2183        match index {
2184            Index::Static(index) => {
2185                write!(self.out, "{index}")?;
2186            }
2187            Index::Expression(index) => {
2188                self.write_expr(module, index, func_ctx)?;
2189            }
2190        }
2191        Ok(())
2192    }
2193
2194    /// Helper method used to write statements
2195    ///
2196    /// # Notes
2197    /// Always adds a newline
2198    fn write_stmt(
2199        &mut self,
2200        module: &Module,
2201        stmt: &crate::Statement,
2202        func_ctx: &back::FunctionCtx<'_>,
2203        level: back::Level,
2204    ) -> BackendResult {
2205        use crate::Statement;
2206
2207        match *stmt {
2208            Statement::Emit(ref range) => {
2209                for handle in range.clone() {
2210                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211                    let expr_name = if ptr_class.is_some() {
2212                        // HLSL can't save a pointer-valued expression in a variable,
2213                        // but we shouldn't ever need to: they should never be named expressions,
2214                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
2215                        None
2216                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217                        // Front end provides names for all variables at the start of writing.
2218                        // But we write them to step by step. We need to recache them
2219                        // Otherwise, we could accidentally write variable name instead of full expression.
2220                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
2221                        Some(self.namer.call(name))
2222                    } else if self.need_bake_expressions.contains(&handle) {
2223                        Some(Baked(handle).to_string())
2224                    } else {
2225                        None
2226                    };
2227
2228                    if let Some(name) = expr_name {
2229                        write!(self.out, "{level}")?;
2230                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231                    }
2232                }
2233            }
2234            // TODO: copy-paste from glsl-out
2235            Statement::Block(ref block) => {
2236                write!(self.out, "{level}")?;
2237                writeln!(self.out, "{{")?;
2238                for sta in block.iter() {
2239                    // Increase the indentation to help with readability
2240                    self.write_stmt(module, sta, func_ctx, level.next())?
2241                }
2242                writeln!(self.out, "{level}}}")?
2243            }
2244            // TODO: copy-paste from glsl-out
2245            Statement::If {
2246                condition,
2247                ref accept,
2248                ref reject,
2249            } => {
2250                write!(self.out, "{level}")?;
2251                write!(self.out, "if (")?;
2252                self.write_expr(module, condition, func_ctx)?;
2253                writeln!(self.out, ") {{")?;
2254
2255                let l2 = level.next();
2256                for sta in accept {
2257                    // Increase indentation to help with readability
2258                    self.write_stmt(module, sta, func_ctx, l2)?;
2259                }
2260
2261                // If there are no statements in the reject block we skip writing it
2262                // This is only for readability
2263                if !reject.is_empty() {
2264                    writeln!(self.out, "{level}}} else {{")?;
2265
2266                    for sta in reject {
2267                        // Increase indentation to help with readability
2268                        self.write_stmt(module, sta, func_ctx, l2)?;
2269                    }
2270                }
2271
2272                writeln!(self.out, "{level}}}")?
2273            }
2274            // TODO: copy-paste from glsl-out
2275            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276            Statement::Return { value: None } => {
2277                writeln!(self.out, "{level}return;")?;
2278            }
2279            Statement::Return { value: Some(expr) } => {
2280                let base_ty_res = &func_ctx.info[expr].ty;
2281                let mut resolved = base_ty_res.inner_with(&module.types);
2282                if let TypeInner::Pointer { base, space: _ } = *resolved {
2283                    resolved = &module.types[base].inner;
2284                }
2285
2286                if let TypeInner::Struct { .. } = *resolved {
2287                    // We can safely unwrap here, since we know we are working with a struct
2288                    let ty = base_ty_res.handle().unwrap();
2289                    let struct_name = &self.names[&NameKey::Type(ty)];
2290                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2291                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292                    self.write_expr(module, expr, func_ctx)?;
2293                    writeln!(self.out, ";")?;
2294
2295                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2296                    let ep_output = match func_ctx.ty {
2297                        back::FunctionType::Function(_) => None,
2298                        back::FunctionType::EntryPoint(index) => self
2299                            .entry_point_io
2300                            .get(&(index as usize))
2301                            .unwrap()
2302                            .output
2303                            .as_ref(),
2304                    };
2305                    let final_name = match ep_output {
2306                        Some(ep_output) => {
2307                            let final_name = self.namer.call(&variable_name);
2308                            write!(
2309                                self.out,
2310                                "{}const {} {} = {{ ",
2311                                level, ep_output.ty_name, final_name,
2312                            )?;
2313                            for (index, m) in ep_output.members.iter().enumerate() {
2314                                if index != 0 {
2315                                    write!(self.out, ", ")?;
2316                                }
2317                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318                                write!(self.out, "{variable_name}.{member_name}")?;
2319                            }
2320                            writeln!(self.out, " }};")?;
2321                            final_name
2322                        }
2323                        None => variable_name,
2324                    };
2325                    writeln!(self.out, "{level}return {final_name};")?;
2326                } else {
2327                    write!(self.out, "{level}return ")?;
2328                    self.write_expr(module, expr, func_ctx)?;
2329                    writeln!(self.out, ";")?
2330                }
2331            }
2332            Statement::Store { pointer, value } => {
2333                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334                if ty_inner.is_atomic_pointer(&module.types) {
2335                    let pointer_space = ty_inner.pointer_space().unwrap();
2336                    let dummy = self.namer.call("dummy");
2337                    write!(self.out, "{level}{{ ")?;
2338                    if let TypeInner::Pointer { base, .. } = *ty_inner {
2339                        self.write_value_type(module, &module.types[base].inner)?;
2340                    }
2341                    write!(self.out, " {dummy} = 0; ")?;
2342                    match pointer_space {
2343                        crate::AddressSpace::WorkGroup => {
2344                            write!(self.out, "InterlockedExchange(")?;
2345                            self.write_expr(module, pointer, func_ctx)?;
2346                        }
2347                        crate::AddressSpace::Storage { .. } => {
2348                            let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349                            let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350                            write!(self.out, "{var_name}.InterlockedExchange(")?;
2351                            let chain = mem::take(&mut self.temp_access_chain);
2352                            self.write_storage_address(module, &chain, func_ctx)?;
2353                            self.temp_access_chain = chain;
2354                        }
2355                        _ => unreachable!(),
2356                    }
2357                    write!(self.out, ", ")?;
2358                    self.write_expr(module, value, func_ctx)?;
2359                    writeln!(self.out, ", {dummy}); }}")?;
2360                } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362                    self.write_storage_store(
2363                        module,
2364                        var_handle,
2365                        StoreValue::Expression(value),
2366                        func_ctx,
2367                        level,
2368                        None,
2369                    )?;
2370                } else {
2371                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2372                    // See the module-level block comment in mod.rs for details.
2373                    //
2374                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2375                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2376                    enum MatrixAccess {
2377                        Direct {
2378                            base: Handle<crate::Expression>,
2379                            index: u32,
2380                        },
2381                        Struct {
2382                            columns: crate::VectorSize,
2383                            base: Handle<crate::Expression>,
2384                        },
2385                    }
2386
2387                    let get_members = |expr: Handle<crate::Expression>| {
2388                        let resolved = func_ctx.resolve_type(expr, &module.types);
2389                        match *resolved {
2390                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2391                                TypeInner::Struct { ref members, .. } => Some(members),
2392                                _ => None,
2393                            },
2394                            _ => None,
2395                        }
2396                    };
2397
2398                    write!(self.out, "{level}")?;
2399
2400                    let matrix_access_on_lhs =
2401                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2402                            |(matrix_expr, vector, scalar)| match (
2403                                func_ctx.resolve_type(matrix_expr, &module.types),
2404                                &func_ctx.expressions[matrix_expr],
2405                            ) {
2406                                (
2407                                    &TypeInner::Pointer { base: ty, .. },
2408                                    &crate::Expression::AccessIndex { base, index },
2409                                ) if matches!(
2410                                    module.types[ty].inner,
2411                                    TypeInner::Matrix {
2412                                        rows: crate::VectorSize::Bi,
2413                                        ..
2414                                    }
2415                                ) && get_members(base)
2416                                    .map(|members| members[index as usize].binding.is_none())
2417                                    == Some(true) =>
2418                                {
2419                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2420                                }
2421                                _ => {
2422                                    if let Some(MatrixType {
2423                                        columns,
2424                                        rows: crate::VectorSize::Bi,
2425                                        width: 4,
2426                                    }) = get_inner_matrix_of_struct_array_member(
2427                                        module,
2428                                        matrix_expr,
2429                                        func_ctx,
2430                                        true,
2431                                    ) {
2432                                        Some((
2433                                            MatrixAccess::Struct {
2434                                                columns,
2435                                                base: matrix_expr,
2436                                            },
2437                                            vector,
2438                                            scalar,
2439                                        ))
2440                                    } else {
2441                                        None
2442                                    }
2443                                }
2444                            },
2445                        );
2446
2447                    match matrix_access_on_lhs {
2448                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2449                            let base_ty_res = &func_ctx.info[base].ty;
2450                            let resolved = base_ty_res.inner_with(&module.types);
2451                            let ty = match *resolved {
2452                                TypeInner::Pointer { base, .. } => base,
2453                                _ => base_ty_res.handle().unwrap(),
2454                            };
2455
2456                            if let Some(Index::Static(vec_index)) = vector {
2457                                self.write_expr(module, base, func_ctx)?;
2458                                write!(
2459                                    self.out,
2460                                    ".{}_{}",
2461                                    &self.names[&NameKey::StructMember(ty, index)],
2462                                    vec_index
2463                                )?;
2464
2465                                if let Some(scalar_index) = scalar {
2466                                    write!(self.out, "[")?;
2467                                    self.write_index(module, scalar_index, func_ctx)?;
2468                                    write!(self.out, "]")?;
2469                                }
2470
2471                                write!(self.out, " = ")?;
2472                                self.write_expr(module, value, func_ctx)?;
2473                                writeln!(self.out, ";")?;
2474                            } else {
2475                                let access = WrappedStructMatrixAccess { ty, index };
2476                                match (&vector, &scalar) {
2477                                    (&Some(_), &Some(_)) => {
2478                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2479                                            access,
2480                                        )?;
2481                                    }
2482                                    (&Some(_), &None) => {
2483                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2484                                            access,
2485                                        )?;
2486                                    }
2487                                    (&None, _) => {
2488                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2489                                    }
2490                                }
2491
2492                                write!(self.out, "(")?;
2493                                self.write_expr(module, base, func_ctx)?;
2494                                write!(self.out, ", ")?;
2495                                self.write_expr(module, value, func_ctx)?;
2496
2497                                if let Some(Index::Expression(vec_index)) = vector {
2498                                    write!(self.out, ", ")?;
2499                                    self.write_expr(module, vec_index, func_ctx)?;
2500
2501                                    if let Some(scalar_index) = scalar {
2502                                        write!(self.out, ", ")?;
2503                                        self.write_index(module, scalar_index, func_ctx)?;
2504                                    }
2505                                }
2506                                writeln!(self.out, ");")?;
2507                            }
2508                        }
2509                        Some((
2510                            MatrixAccess::Struct { columns, base },
2511                            Some(Index::Expression(vec_index)),
2512                            scalar,
2513                        )) => {
2514                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2515                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2516
2517                            if scalar.is_some() {
2518                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2519                            } else {
2520                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2521                            }
2522                            write!(self.out, "(")?;
2523                            self.write_expr(module, base, func_ctx)?;
2524                            write!(self.out, ", ")?;
2525                            self.write_expr(module, vec_index, func_ctx)?;
2526
2527                            if let Some(scalar_index) = scalar {
2528                                write!(self.out, ", ")?;
2529                                self.write_index(module, scalar_index, func_ctx)?;
2530                            }
2531
2532                            write!(self.out, ", ")?;
2533                            self.write_expr(module, value, func_ctx)?;
2534
2535                            writeln!(self.out, ");")?;
2536                        }
2537                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2538                        | Some((MatrixAccess::Struct { .. }, None, _))
2539                        | None => {
2540                            self.write_expr(module, pointer, func_ctx)?;
2541                            write!(self.out, " = ")?;
2542
2543                            // We cast the RHS of this store in cases where the LHS
2544                            // is a struct member with type:
2545                            //  - matCx2 or
2546                            //  - a (possibly nested) array of matCx2's
2547                            if let Some(MatrixType {
2548                                columns,
2549                                rows: crate::VectorSize::Bi,
2550                                width: 4,
2551                            }) = get_inner_matrix_of_struct_array_member(
2552                                module, pointer, func_ctx, false,
2553                            ) {
2554                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2555                                if let TypeInner::Pointer { base, .. } = *resolved {
2556                                    resolved = &module.types[base].inner;
2557                                }
2558
2559                                write!(self.out, "(__mat{}x2", columns as u8)?;
2560                                if let TypeInner::Array { base, size, .. } = *resolved {
2561                                    self.write_array_size(module, base, size)?;
2562                                }
2563                                write!(self.out, ")")?;
2564                            }
2565
2566                            self.write_expr(module, value, func_ctx)?;
2567                            writeln!(self.out, ";")?
2568                        }
2569                    }
2570                }
2571            }
2572            Statement::Loop {
2573                ref body,
2574                ref continuing,
2575                break_if,
2576            } => {
2577                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2578                let gate_name = (!continuing.is_empty() || break_if.is_some())
2579                    .then(|| self.namer.call("loop_init"));
2580
2581                if let Some((ref decl, _)) = force_loop_bound_statements {
2582                    writeln!(self.out, "{decl}")?;
2583                }
2584                if let Some(ref gate_name) = gate_name {
2585                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2586                }
2587
2588                self.continue_ctx.enter_loop();
2589                writeln!(self.out, "{level}while(true) {{")?;
2590                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2591                    writeln!(self.out, "{break_and_inc}")?;
2592                }
2593                let l2 = level.next();
2594                if let Some(gate_name) = gate_name {
2595                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2596                    let l3 = l2.next();
2597                    for sta in continuing.iter() {
2598                        self.write_stmt(module, sta, func_ctx, l3)?;
2599                    }
2600                    if let Some(condition) = break_if {
2601                        write!(self.out, "{l3}if (")?;
2602                        self.write_expr(module, condition, func_ctx)?;
2603                        writeln!(self.out, ") {{")?;
2604                        writeln!(self.out, "{}break;", l3.next())?;
2605                        writeln!(self.out, "{l3}}}")?;
2606                    }
2607                    writeln!(self.out, "{l2}}}")?;
2608                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2609                }
2610
2611                for sta in body.iter() {
2612                    self.write_stmt(module, sta, func_ctx, l2)?;
2613                }
2614
2615                writeln!(self.out, "{level}}}")?;
2616                self.continue_ctx.exit_loop();
2617            }
2618            Statement::Break => writeln!(self.out, "{level}break;")?,
2619            Statement::Continue => {
2620                if let Some(variable) = self.continue_ctx.continue_encountered() {
2621                    writeln!(self.out, "{level}{variable} = true;")?;
2622                    writeln!(self.out, "{level}break;")?
2623                } else {
2624                    writeln!(self.out, "{level}continue;")?
2625                }
2626            }
2627            Statement::ControlBarrier(barrier) => {
2628                self.write_control_barrier(barrier, level)?;
2629            }
2630            Statement::MemoryBarrier(barrier) => {
2631                self.write_memory_barrier(barrier, level)?;
2632            }
2633            Statement::ImageStore {
2634                image,
2635                coordinate,
2636                array_index,
2637                value,
2638            } => {
2639                write!(self.out, "{level}")?;
2640                self.write_expr(module, image, func_ctx)?;
2641
2642                write!(self.out, "[")?;
2643                if let Some(index) = array_index {
2644                    // An array index is accepted only for texture_storage_2d_array, so we can safely use int3(coordinate, array_index) here
2645                    write!(self.out, "int3(")?;
2646                    self.write_expr(module, coordinate, func_ctx)?;
2647                    write!(self.out, ", ")?;
2648                    self.write_expr(module, index, func_ctx)?;
2649                    write!(self.out, ")")?;
2650                } else {
2651                    self.write_expr(module, coordinate, func_ctx)?;
2652                }
2653                write!(self.out, "]")?;
2654
2655                write!(self.out, " = ")?;
2656                self.write_expr(module, value, func_ctx)?;
2657                writeln!(self.out, ";")?;
2658            }
2659            Statement::Call {
2660                function,
2661                ref arguments,
2662                result,
2663            } => {
2664                write!(self.out, "{level}")?;
2665
2666                if let Some(expr) = result {
2667                    write!(self.out, "const ")?;
2668                    let name = Baked(expr).to_string();
2669                    let expr_ty = &func_ctx.info[expr].ty;
2670                    let ty_inner = match *expr_ty {
2671                        proc::TypeResolution::Handle(handle) => {
2672                            self.write_type(module, handle)?;
2673                            &module.types[handle].inner
2674                        }
2675                        proc::TypeResolution::Value(ref value) => {
2676                            self.write_value_type(module, value)?;
2677                            value
2678                        }
2679                    };
2680                    write!(self.out, " {name}")?;
2681                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2682                        self.write_array_size(module, base, size)?;
2683                    }
2684                    write!(self.out, " = ")?;
2685                    self.named_expressions.insert(expr, name);
2686                }
2687                let func_name = &self.names[&NameKey::Function(function)];
2688                write!(self.out, "{func_name}(")?;
2689                let mut any_args_written = false;
2690                let mut separator = || {
2691                    if any_args_written {
2692                        ", "
2693                    } else {
2694                        any_args_written = true;
2695                        ""
2696                    }
2697                };
2698                for argument in arguments {
2699                    write!(self.out, "{}", separator())?;
2700                    self.write_expr(module, *argument, func_ctx)?;
2701                }
2702                if let Some(&var) = self.function_task_payload_var.get(&function) {
2703                    let name = &self.names[&NameKey::GlobalVariable(var)];
2704                    // Pass it through directly, whether it's an `in` variable to this function or the global variable
2705                    write!(self.out, "{}{name}", separator())?;
2706                }
2707                writeln!(self.out, ");")?;
2708            }
2709            Statement::Atomic {
2710                pointer,
2711                ref fun,
2712                value,
2713                result,
2714            } => {
2715                write!(self.out, "{level}")?;
2716                let res_var_info = if let Some(res_handle) = result {
2717                    let name = Baked(res_handle).to_string();
2718                    match func_ctx.info[res_handle].ty {
2719                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2720                        proc::TypeResolution::Value(ref value) => {
2721                            self.write_value_type(module, value)?
2722                        }
2723                    };
2724                    write!(self.out, " {name}; ")?;
2725                    self.named_expressions.insert(res_handle, name.clone());
2726                    Some((res_handle, name))
2727                } else {
2728                    None
2729                };
2730                let pointer_space = func_ctx
2731                    .resolve_type(pointer, &module.types)
2732                    .pointer_space()
2733                    .unwrap();
2734                let fun_str = fun.to_hlsl_suffix();
2735                let compare_expr = match *fun {
2736                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2737                    _ => None,
2738                };
2739                match pointer_space {
2740                    crate::AddressSpace::WorkGroup => {
2741                        write!(self.out, "Interlocked{fun_str}(")?;
2742                        self.write_expr(module, pointer, func_ctx)?;
2743                        self.emit_hlsl_atomic_tail(
2744                            module,
2745                            func_ctx,
2746                            fun,
2747                            compare_expr,
2748                            value,
2749                            &res_var_info,
2750                        )?;
2751                    }
2752                    crate::AddressSpace::Storage { .. } => {
2753                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2754                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2755                        let width = match func_ctx.resolve_type(value, &module.types) {
2756                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2757                            _ => "",
2758                        };
2759                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2760                        let chain = mem::take(&mut self.temp_access_chain);
2761                        self.write_storage_address(module, &chain, func_ctx)?;
2762                        self.temp_access_chain = chain;
2763                        self.emit_hlsl_atomic_tail(
2764                            module,
2765                            func_ctx,
2766                            fun,
2767                            compare_expr,
2768                            value,
2769                            &res_var_info,
2770                        )?;
2771                    }
2772                    ref other => {
2773                        return Err(Error::Custom(format!(
2774                            "invalid address space {other:?} for atomic statement"
2775                        )))
2776                    }
2777                }
2778                if let Some(cmp) = compare_expr {
2779                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2780                        write!(
2781                            self.out,
2782                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2783                        )?;
2784                        self.write_expr(module, cmp, func_ctx)?;
2785                        writeln!(self.out, ");")?;
2786                    }
2787                }
2788            }
2789            Statement::ImageAtomic {
2790                image,
2791                coordinate,
2792                array_index,
2793                fun,
2794                value,
2795            } => {
2796                write!(self.out, "{level}")?;
2797
2798                let fun_str = fun.to_hlsl_suffix();
2799                write!(self.out, "Interlocked{fun_str}(")?;
2800                self.write_expr(module, image, func_ctx)?;
2801                write!(self.out, "[")?;
2802                self.write_texture_coordinates(
2803                    "int",
2804                    coordinate,
2805                    array_index,
2806                    None,
2807                    module,
2808                    func_ctx,
2809                )?;
2810                write!(self.out, "],")?;
2811
2812                self.write_expr(module, value, func_ctx)?;
2813                writeln!(self.out, ");")?;
2814            }
2815            Statement::WorkGroupUniformLoad { pointer, result } => {
2816                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2817                write!(self.out, "{level}")?;
2818                let name = Baked(result).to_string();
2819                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2820
2821                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2822            }
2823            Statement::Switch {
2824                selector,
2825                ref cases,
2826            } => {
2827                self.write_switch(module, func_ctx, level, selector, cases)?;
2828            }
2829            Statement::RayQuery { query, ref fun } => {
2830                // There are three possibilities for a ptr to be:
2831                // 1. A variable
2832                // 2. A function argument
2833                // 3. part of a struct
2834                //
2835                // 2 and 3 are not possible, a ray query (in naga IR)
2836                // is not allowed to be passed into a function, and
2837                // all languages disallow it in a struct (you get fun results if
2838                // you try it :) ).
2839                //
2840                // Therefore, the ray query expression must be a variable.
2841                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2842                else {
2843                    unreachable!()
2844                };
2845
2846                let tracker_expr_name = format!(
2847                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2848                    self.names[&func_ctx.name_key(query_var)]
2849                );
2850
2851                match *fun {
2852                    RayQueryFunction::Initialize {
2853                        acceleration_structure,
2854                        descriptor,
2855                    } => {
2856                        self.write_initialize_function(
2857                            module,
2858                            level,
2859                            query,
2860                            acceleration_structure,
2861                            descriptor,
2862                            &tracker_expr_name,
2863                            func_ctx,
2864                        )?;
2865                    }
2866                    RayQueryFunction::Proceed { result } => {
2867                        self.write_proceed(
2868                            module,
2869                            level,
2870                            query,
2871                            result,
2872                            &tracker_expr_name,
2873                            func_ctx,
2874                        )?;
2875                    }
2876                    RayQueryFunction::GenerateIntersection { hit_t } => {
2877                        self.write_generate_intersection(
2878                            module,
2879                            level,
2880                            query,
2881                            hit_t,
2882                            &tracker_expr_name,
2883                            func_ctx,
2884                        )?;
2885                    }
2886                    RayQueryFunction::ConfirmIntersection => {
2887                        self.write_confirm_intersection(
2888                            module,
2889                            level,
2890                            query,
2891                            &tracker_expr_name,
2892                            func_ctx,
2893                        )?;
2894                    }
2895                    RayQueryFunction::Terminate => {
2896                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2897                    }
2898                }
2899            }
2900            Statement::SubgroupBallot { result, predicate } => {
2901                write!(self.out, "{level}")?;
2902                let name = Baked(result).to_string();
2903                write!(self.out, "const uint4 {name} = ")?;
2904                self.named_expressions.insert(result, name);
2905
2906                write!(self.out, "WaveActiveBallot(")?;
2907                match predicate {
2908                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2909                    None => write!(self.out, "true")?,
2910                }
2911                writeln!(self.out, ");")?;
2912            }
2913            Statement::SubgroupCollectiveOperation {
2914                op,
2915                collective_op,
2916                argument,
2917                result,
2918            } => {
2919                write!(self.out, "{level}")?;
2920                write!(self.out, "const ")?;
2921                let name = Baked(result).to_string();
2922                match func_ctx.info[result].ty {
2923                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2924                    proc::TypeResolution::Value(ref value) => {
2925                        self.write_value_type(module, value)?
2926                    }
2927                };
2928                write!(self.out, " {name} = ")?;
2929                self.named_expressions.insert(result, name);
2930
2931                match (collective_op, op) {
2932                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2933                        write!(self.out, "WaveActiveAllTrue(")?
2934                    }
2935                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2936                        write!(self.out, "WaveActiveAnyTrue(")?
2937                    }
2938                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2939                        write!(self.out, "WaveActiveSum(")?
2940                    }
2941                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2942                        write!(self.out, "WaveActiveProduct(")?
2943                    }
2944                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2945                        write!(self.out, "WaveActiveMax(")?
2946                    }
2947                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2948                        write!(self.out, "WaveActiveMin(")?
2949                    }
2950                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2951                        write!(self.out, "WaveActiveBitAnd(")?
2952                    }
2953                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2954                        write!(self.out, "WaveActiveBitOr(")?
2955                    }
2956                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2957                        write!(self.out, "WaveActiveBitXor(")?
2958                    }
2959                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2960                        write!(self.out, "WavePrefixSum(")?
2961                    }
2962                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2963                        write!(self.out, "WavePrefixProduct(")?
2964                    }
2965                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2966                        self.write_expr(module, argument, func_ctx)?;
2967                        write!(self.out, " + WavePrefixSum(")?;
2968                    }
2969                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2970                        self.write_expr(module, argument, func_ctx)?;
2971                        write!(self.out, " * WavePrefixProduct(")?;
2972                    }
2973                    _ => unimplemented!(),
2974                }
2975                self.write_expr(module, argument, func_ctx)?;
2976                writeln!(self.out, ");")?;
2977            }
2978            Statement::SubgroupGather {
2979                mode,
2980                argument,
2981                result,
2982            } => {
2983                write!(self.out, "{level}")?;
2984                write!(self.out, "const ")?;
2985                let name = Baked(result).to_string();
2986                match func_ctx.info[result].ty {
2987                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2988                    proc::TypeResolution::Value(ref value) => {
2989                        self.write_value_type(module, value)?
2990                    }
2991                };
2992                write!(self.out, " {name} = ")?;
2993                self.named_expressions.insert(result, name);
2994                match mode {
2995                    crate::GatherMode::BroadcastFirst => {
2996                        write!(self.out, "WaveReadLaneFirst(")?;
2997                        self.write_expr(module, argument, func_ctx)?;
2998                    }
2999                    crate::GatherMode::QuadBroadcast(index) => {
3000                        write!(self.out, "QuadReadLaneAt(")?;
3001                        self.write_expr(module, argument, func_ctx)?;
3002                        write!(self.out, ", ")?;
3003                        self.write_expr(module, index, func_ctx)?;
3004                    }
3005                    crate::GatherMode::QuadSwap(direction) => {
3006                        match direction {
3007                            crate::Direction::X => {
3008                                write!(self.out, "QuadReadAcrossX(")?;
3009                            }
3010                            crate::Direction::Y => {
3011                                write!(self.out, "QuadReadAcrossY(")?;
3012                            }
3013                            crate::Direction::Diagonal => {
3014                                write!(self.out, "QuadReadAcrossDiagonal(")?;
3015                            }
3016                        }
3017                        self.write_expr(module, argument, func_ctx)?;
3018                    }
3019                    _ => {
3020                        write!(self.out, "WaveReadLaneAt(")?;
3021                        self.write_expr(module, argument, func_ctx)?;
3022                        write!(self.out, ", ")?;
3023                        match mode {
3024                            crate::GatherMode::BroadcastFirst => unreachable!(),
3025                            crate::GatherMode::Broadcast(index)
3026                            | crate::GatherMode::Shuffle(index) => {
3027                                self.write_expr(module, index, func_ctx)?;
3028                            }
3029                            crate::GatherMode::ShuffleDown(index) => {
3030                                write!(self.out, "WaveGetLaneIndex() + ")?;
3031                                self.write_expr(module, index, func_ctx)?;
3032                            }
3033                            crate::GatherMode::ShuffleUp(index) => {
3034                                write!(self.out, "WaveGetLaneIndex() - ")?;
3035                                self.write_expr(module, index, func_ctx)?;
3036                            }
3037                            crate::GatherMode::ShuffleXor(index) => {
3038                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
3039                                self.write_expr(module, index, func_ctx)?;
3040                            }
3041                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3042                            crate::GatherMode::QuadSwap(_) => unreachable!(),
3043                        }
3044                    }
3045                }
3046                writeln!(self.out, ");")?;
3047            }
3048            Statement::CooperativeStore { .. } => unimplemented!(),
3049            Statement::RayPipelineFunction(_) => unreachable!(),
3050        }
3051
3052        Ok(())
3053    }
3054
3055    fn write_const_expression(
3056        &mut self,
3057        module: &Module,
3058        expr: Handle<crate::Expression>,
3059        arena: &crate::Arena<crate::Expression>,
3060    ) -> BackendResult {
3061        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3062            writer.write_const_expression(module, expr, arena)
3063        })
3064    }
3065
3066    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3067        match literal {
3068            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3069            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3070            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3071            crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
3072            crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
3073            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3074            // `-2147483648` is parsed by some compilers as unary negation of
3075            // positive 2147483648, which is too large for an int, causing
3076            // issues for some compilers. Neither DXC nor FXC appear to have
3077            // this problem, but this is not specified and could change. We
3078            // therefore use `-2147483647 - 1` as a precaution.
3079            crate::Literal::I32(value) if value == i32::MIN => {
3080                write!(self.out, "int({} - 1)", value + 1)?
3081            }
3082            // HLSL has no suffix for explicit i32 literals, but not using any suffix
3083            // makes the type ambiguous which prevents overload resolution from
3084            // working. So we explicitly use the int() constructor syntax.
3085            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3086            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3087            // I64 version of the minimum I32 value issue described above.
3088            crate::Literal::I64(value) if value == i64::MIN => {
3089                write!(self.out, "({}L - 1L)", value + 1)?;
3090            }
3091            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3092            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3093            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3094                return Err(Error::Custom(
3095                    "Abstract types should not appear in IR presented to backends".into(),
3096                ));
3097            }
3098        }
3099        Ok(())
3100    }
3101
3102    fn write_possibly_const_expression<E>(
3103        &mut self,
3104        module: &Module,
3105        expr: Handle<crate::Expression>,
3106        expressions: &crate::Arena<crate::Expression>,
3107        write_expression: E,
3108    ) -> BackendResult
3109    where
3110        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3111    {
3112        use crate::Expression;
3113
3114        match expressions[expr] {
3115            Expression::Literal(literal) => {
3116                self.write_literal(literal)?;
3117            }
3118            Expression::Constant(handle) => {
3119                let constant = &module.constants[handle];
3120                if constant.name.is_some() {
3121                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3122                } else {
3123                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
3124                }
3125            }
3126            Expression::ZeroValue(ty) => {
3127                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3128                write!(self.out, "()")?;
3129            }
3130            Expression::Compose { ty, ref components } => {
3131                match module.types[ty].inner {
3132                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3133                        self.write_wrapped_constructor_function_name(
3134                            module,
3135                            WrappedConstructor { ty },
3136                        )?;
3137                    }
3138                    _ => {
3139                        self.write_type(module, ty)?;
3140                    }
3141                };
3142                write!(self.out, "(")?;
3143                for (index, component) in components.iter().enumerate() {
3144                    if index != 0 {
3145                        write!(self.out, ", ")?;
3146                    }
3147                    write_expression(self, *component)?;
3148                }
3149                write!(self.out, ")")?;
3150            }
3151            Expression::Splat { size, value } => {
3152                // hlsl is not supported one value constructor
3153                // if we write, for example, int4(0), dxc returns error:
3154                // error: too few elements in vector initialization (expected 4 elements, have 1)
3155                let number_of_components = match size {
3156                    crate::VectorSize::Bi => "xx",
3157                    crate::VectorSize::Tri => "xxx",
3158                    crate::VectorSize::Quad => "xxxx",
3159                };
3160                write!(self.out, "(")?;
3161                write_expression(self, value)?;
3162                write!(self.out, ").{number_of_components}")?
3163            }
3164            _ => {
3165                return Err(Error::Override);
3166            }
3167        }
3168
3169        Ok(())
3170    }
3171
3172    /// Helper method to write expressions
3173    ///
3174    /// # Notes
3175    /// Doesn't add any newlines or leading/trailing spaces
3176    pub(super) fn write_expr(
3177        &mut self,
3178        module: &Module,
3179        expr: Handle<crate::Expression>,
3180        func_ctx: &back::FunctionCtx<'_>,
3181    ) -> BackendResult {
3182        use crate::Expression;
3183
3184        // Handle the special semantics of vertex_index/instance_index
3185        let ff_input = if self.options.special_constants_binding.is_some() {
3186            func_ctx.is_fixed_function_input(expr, module)
3187        } else {
3188            None
3189        };
3190        let closing_bracket = match ff_input {
3191            Some(crate::BuiltIn::VertexIndex) => {
3192                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3193                ")"
3194            }
3195            Some(crate::BuiltIn::InstanceIndex) => {
3196                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3197                ")"
3198            }
3199            Some(crate::BuiltIn::NumWorkGroups) => {
3200                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
3201                // in compute shaders the special constants contain the number
3202                // of workgroups, which we are using here.
3203                write!(
3204                    self.out,
3205                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3206                )?;
3207                return Ok(());
3208            }
3209            _ => "",
3210        };
3211
3212        if let Some(name) = self.named_expressions.get(&expr) {
3213            write!(self.out, "{name}{closing_bracket}")?;
3214            return Ok(());
3215        }
3216
3217        let expression = &func_ctx.expressions[expr];
3218
3219        match *expression {
3220            Expression::Literal(_)
3221            | Expression::Constant(_)
3222            | Expression::ZeroValue(_)
3223            | Expression::Compose { .. }
3224            | Expression::Splat { .. } => {
3225                self.write_possibly_const_expression(
3226                    module,
3227                    expr,
3228                    func_ctx.expressions,
3229                    |writer, expr| writer.write_expr(module, expr, func_ctx),
3230                )?;
3231            }
3232            Expression::Override(_) => return Err(Error::Override),
3233            // Avoid undefined behaviour for addition, subtraction, and
3234            // multiplication of signed integers by casting operands to
3235            // unsigned, performing the operation, then casting the result back
3236            // to signed.
3237            // TODO(#7109): This relies on the asint()/asuint() functions which only work
3238            // for 32-bit types, so we must find another solution for different bit widths.
3239            Expression::Binary {
3240                op:
3241                    op @ crate::BinaryOperator::Add
3242                    | op @ crate::BinaryOperator::Subtract
3243                    | op @ crate::BinaryOperator::Multiply,
3244                left,
3245                right,
3246            } if matches!(
3247                func_ctx.resolve_type(expr, &module.types).scalar(),
3248                Some(Scalar::I32)
3249            ) =>
3250            {
3251                write!(self.out, "asint(asuint(",)?;
3252                self.write_expr(module, left, func_ctx)?;
3253                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3254                self.write_expr(module, right, func_ctx)?;
3255                write!(self.out, "))")?;
3256            }
3257            // All of the multiplication can be expressed as `mul`,
3258            // except vector * vector, which needs to use the "*" operator.
3259            Expression::Binary {
3260                op: crate::BinaryOperator::Multiply,
3261                left,
3262                right,
3263            } if func_ctx.resolve_type(left, &module.types).is_matrix()
3264                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3265            {
3266                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
3267                write!(self.out, "mul(")?;
3268                self.write_expr(module, right, func_ctx)?;
3269                write!(self.out, ", ")?;
3270                self.write_expr(module, left, func_ctx)?;
3271                write!(self.out, ")")?;
3272            }
3273
3274            // WGSL says that floating-point division by zero should return
3275            // infinity. Microsoft's Direct3D 11 functional specification
3276            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
3277            // says:
3278            //
3279            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3280            //
3281            // which is what we want. The DXIL specification for the FDiv
3282            // instruction corroborates this:
3283            //
3284            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3285            Expression::Binary {
3286                op: crate::BinaryOperator::Divide,
3287                left,
3288                right,
3289            } if matches!(
3290                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3291                Some(ScalarKind::Sint | ScalarKind::Uint)
3292            ) =>
3293            {
3294                write!(self.out, "{DIV_FUNCTION}(")?;
3295                self.write_expr(module, left, func_ctx)?;
3296                write!(self.out, ", ")?;
3297                self.write_expr(module, right, func_ctx)?;
3298                write!(self.out, ")")?;
3299            }
3300
3301            Expression::Binary {
3302                op: crate::BinaryOperator::Modulo,
3303                left,
3304                right,
3305            } if matches!(
3306                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3307                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3308            ) =>
3309            {
3310                write!(self.out, "{MOD_FUNCTION}(")?;
3311                self.write_expr(module, left, func_ctx)?;
3312                write!(self.out, ", ")?;
3313                self.write_expr(module, right, func_ctx)?;
3314                write!(self.out, ")")?;
3315            }
3316
3317            Expression::Binary { op, left, right } => {
3318                write!(self.out, "(")?;
3319                self.write_expr(module, left, func_ctx)?;
3320                write!(self.out, " {} ", back::binary_operation_str(op))?;
3321                self.write_expr(module, right, func_ctx)?;
3322                write!(self.out, ")")?;
3323            }
3324            Expression::Access { base, index } => {
3325                if let Some(crate::AddressSpace::Storage { .. }) =
3326                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3327                {
3328                    // do nothing, the chain is written on `Load`/`Store`
3329                } else {
3330                    // We use the function __get_col_of_matCx2 here in cases
3331                    // where `base`s type resolves to a matCx2 and is part of a
3332                    // struct member with type of (possibly nested) array of matCx2's.
3333                    //
3334                    // Note that this only works for `Load`s and we handle
3335                    // `Store`s differently in `Statement::Store`.
3336                    if let Some(MatrixType {
3337                        columns,
3338                        rows: crate::VectorSize::Bi,
3339                        width: 4,
3340                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3341                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3342                    {
3343                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3344                        self.write_expr(module, base, func_ctx)?;
3345                        write!(self.out, ", ")?;
3346                        self.write_expr(module, index, func_ctx)?;
3347                        write!(self.out, ")")?;
3348                        return Ok(());
3349                    }
3350
3351                    let resolved = func_ctx.resolve_type(base, &module.types);
3352
3353                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3354                        TypeInner::BindingArray { .. } => {
3355                            let uniformity = &func_ctx.info[index].uniformity;
3356
3357                            (true, uniformity.non_uniform_result.is_some())
3358                        }
3359                        _ => (false, false),
3360                    };
3361
3362                    self.write_expr(module, base, func_ctx)?;
3363
3364                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3365                        module, func_ctx, base, resolved,
3366                    );
3367
3368                    if let Some(ref info) = array_sampler_info {
3369                        write!(self.out, "{}[", info.sampler_heap_name)?;
3370                    } else {
3371                        write!(self.out, "[")?;
3372                    }
3373
3374                    let needs_bound_check = self.options.restrict_indexing
3375                        && !indexing_binding_array
3376                        && match resolved.pointer_space() {
3377                            Some(
3378                                crate::AddressSpace::Function
3379                                | crate::AddressSpace::Private
3380                                | crate::AddressSpace::WorkGroup
3381                                | crate::AddressSpace::Immediate
3382                                | crate::AddressSpace::TaskPayload
3383                                | crate::AddressSpace::RayPayload
3384                                | crate::AddressSpace::IncomingRayPayload,
3385                            )
3386                            | None => true,
3387                            Some(crate::AddressSpace::Uniform) => {
3388                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3389                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3390                                let bind_target = self
3391                                    .options
3392                                    .resolve_resource_binding(
3393                                        module.global_variables[var_handle]
3394                                            .binding
3395                                            .as_ref()
3396                                            .unwrap(),
3397                                    )
3398                                    .unwrap();
3399                                bind_target.restrict_indexing
3400                            }
3401                            Some(
3402                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3403                            ) => unreachable!(),
3404                        };
3405                    // Decide whether this index needs to be clamped to fall within range.
3406                    let restriction_needed = if needs_bound_check {
3407                        index::access_needs_check(
3408                            base,
3409                            index::GuardedIndex::Expression(index),
3410                            module,
3411                            func_ctx.expressions,
3412                            func_ctx.info,
3413                        )
3414                    } else {
3415                        None
3416                    };
3417                    if let Some(limit) = restriction_needed {
3418                        write!(self.out, "min(uint(")?;
3419                        self.write_expr(module, index, func_ctx)?;
3420                        write!(self.out, "), ")?;
3421                        match limit {
3422                            index::IndexableLength::Known(limit) => {
3423                                write!(self.out, "{}u", limit - 1)?;
3424                            }
3425                            index::IndexableLength::Dynamic => unreachable!(),
3426                        }
3427                        write!(self.out, ")")?;
3428                    } else {
3429                        if non_uniform_qualifier {
3430                            write!(self.out, "NonUniformResourceIndex(")?;
3431                        }
3432                        if let Some(ref info) = array_sampler_info {
3433                            write!(
3434                                self.out,
3435                                "{}[{} + ",
3436                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3437                            )?;
3438                        }
3439                        self.write_expr(module, index, func_ctx)?;
3440                        if array_sampler_info.is_some() {
3441                            write!(self.out, "]")?;
3442                        }
3443                        if non_uniform_qualifier {
3444                            write!(self.out, ")")?;
3445                        }
3446                    }
3447
3448                    write!(self.out, "]")?;
3449                }
3450            }
3451            Expression::AccessIndex { base, index } => {
3452                if let Some(crate::AddressSpace::Storage { .. }) =
3453                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3454                {
3455                    // do nothing, the chain is written on `Load`/`Store`
3456                } else {
3457                    // See if we need to write the matrix column access in a
3458                    // special way since the type of `base` is our special
3459                    // __matCx2 struct.
3460                    if let Some(MatrixType {
3461                        rows: crate::VectorSize::Bi,
3462                        width: 4,
3463                        ..
3464                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3465                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3466                    {
3467                        self.write_expr(module, base, func_ctx)?;
3468                        write!(self.out, "._{index}")?;
3469                        return Ok(());
3470                    }
3471
3472                    let base_ty_res = &func_ctx.info[base].ty;
3473                    let mut resolved = base_ty_res.inner_with(&module.types);
3474                    let base_ty_handle = match *resolved {
3475                        TypeInner::Pointer { base, .. } => {
3476                            resolved = &module.types[base].inner;
3477                            Some(base)
3478                        }
3479                        _ => base_ty_res.handle(),
3480                    };
3481
3482                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3483                    // See the module-level block comment in mod.rs for details.
3484                    //
3485                    // We handle matrix reconstruction here for Loads.
3486                    // Stores are handled directly by `Statement::Store`.
3487                    if let TypeInner::Struct { ref members, .. } = *resolved {
3488                        let member = &members[index as usize];
3489
3490                        match module.types[member.ty].inner {
3491                            TypeInner::Matrix {
3492                                rows: crate::VectorSize::Bi,
3493                                ..
3494                            } if member.binding.is_none() => {
3495                                let ty = base_ty_handle.unwrap();
3496                                self.write_wrapped_struct_matrix_get_function_name(
3497                                    WrappedStructMatrixAccess { ty, index },
3498                                )?;
3499                                write!(self.out, "(")?;
3500                                self.write_expr(module, base, func_ctx)?;
3501                                write!(self.out, ")")?;
3502                                return Ok(());
3503                            }
3504                            _ => {}
3505                        }
3506                    }
3507
3508                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3509                        module, func_ctx, base, resolved,
3510                    );
3511
3512                    if let Some(ref info) = array_sampler_info {
3513                        write!(
3514                            self.out,
3515                            "{}[{}",
3516                            info.sampler_heap_name, info.sampler_index_buffer_name
3517                        )?;
3518                    }
3519
3520                    self.write_expr(module, base, func_ctx)?;
3521
3522                    match *resolved {
3523                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3524                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3525                        // it and generates completely nonsensical DXBC.
3526                        //
3527                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3528                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3529                            // Write vector access as a swizzle
3530                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3531                        }
3532                        TypeInner::Matrix { .. }
3533                        | TypeInner::Array { .. }
3534                        | TypeInner::BindingArray { .. } => {
3535                            if let Some(ref info) = array_sampler_info {
3536                                write!(
3537                                    self.out,
3538                                    "[{} + {index}]",
3539                                    info.binding_array_base_index_name
3540                                )?;
3541                            } else {
3542                                write!(self.out, "[{index}]")?;
3543                            }
3544                        }
3545                        TypeInner::Struct { .. } => {
3546                            // `base_ty_handle` is always `Some` when the type is a `Struct`. That is not
3547                            // guaranteed for other types, so we can only unwrap inside this match arm.
3548                            let ty = base_ty_handle.unwrap();
3549
3550                            write!(
3551                                self.out,
3552                                ".{}",
3553                                &self.names[&NameKey::StructMember(ty, index)]
3554                            )?
3555                        }
3556                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3557                    }
3558
3559                    if array_sampler_info.is_some() {
3560                        write!(self.out, "]")?;
3561                    }
3562                }
3563            }
3564            Expression::FunctionArgument(pos) => {
3565                let ty = func_ctx.resolve_type(expr, &module.types);
3566
3567                // We know that any external texture function argument has been expanded into
3568                // separate consecutive arguments for each plane and the parameters buffer. And we
3569                // also know that external textures can only ever be used as an argument to another
3570                // function. Therefore we can simply emit each of the expanded arguments in a
3571                // consecutive comma-separated list.
3572                if let TypeInner::Image {
3573                    class: crate::ImageClass::External,
3574                    ..
3575                } = *ty
3576                {
3577                    let plane_names = [0, 1, 2].map(|i| {
3578                        &self.names[&func_ctx
3579                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3580                    });
3581                    let params_name = &self.names[&func_ctx
3582                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3583                    write!(
3584                        self.out,
3585                        "{}, {}, {}, {}",
3586                        plane_names[0], plane_names[1], plane_names[2], params_name
3587                    )?;
3588                } else {
3589                    let key = func_ctx.argument_key(pos);
3590                    let name = &self.names[&key];
3591                    write!(self.out, "{name}")?;
3592                }
3593            }
3594            Expression::ImageSample {
3595                coordinate,
3596                image,
3597                sampler,
3598                clamp_to_edge: true,
3599                gather: None,
3600                array_index: None,
3601                offset: None,
3602                level: crate::SampleLevel::Zero,
3603                depth_ref: None,
3604            } => {
3605                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3606                self.write_expr(module, image, func_ctx)?;
3607                write!(self.out, ", ")?;
3608                self.write_expr(module, sampler, func_ctx)?;
3609                write!(self.out, ", ")?;
3610                self.write_expr(module, coordinate, func_ctx)?;
3611                write!(self.out, ")")?;
3612            }
3613            Expression::ImageSample {
3614                image,
3615                sampler,
3616                gather,
3617                coordinate,
3618                array_index,
3619                offset,
3620                level,
3621                depth_ref,
3622                clamp_to_edge,
3623            } => {
3624                if clamp_to_edge {
3625                    return Err(Error::Custom(
3626                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3627                    ));
3628                }
3629
3630                use crate::SampleLevel as Sl;
3631                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3632
3633                let (base_str, component_str) = match gather {
3634                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3635                    None => ("Sample", ""),
3636                };
3637                let cmp_str = match depth_ref {
3638                    Some(_) => "Cmp",
3639                    None => "",
3640                };
3641                let level_str = match level {
3642                    Sl::Zero if gather.is_none() => "LevelZero",
3643                    Sl::Auto | Sl::Zero => "",
3644                    Sl::Exact(_) => "Level",
3645                    Sl::Bias(_) => "Bias",
3646                    Sl::Gradient { .. } => "Grad",
3647                };
3648
3649                self.write_expr(module, image, func_ctx)?;
3650                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3651                self.write_expr(module, sampler, func_ctx)?;
3652                write!(self.out, ", ")?;
3653                self.write_texture_coordinates(
3654                    "float",
3655                    coordinate,
3656                    array_index,
3657                    None,
3658                    module,
3659                    func_ctx,
3660                )?;
3661
3662                if let Some(depth_ref) = depth_ref {
3663                    write!(self.out, ", ")?;
3664                    self.write_expr(module, depth_ref, func_ctx)?;
3665                }
3666
3667                match level {
3668                    Sl::Auto | Sl::Zero => {}
3669                    Sl::Exact(expr) => {
3670                        write!(self.out, ", ")?;
3671                        self.write_expr(module, expr, func_ctx)?;
3672                    }
3673                    Sl::Bias(expr) => {
3674                        write!(self.out, ", ")?;
3675                        self.write_expr(module, expr, func_ctx)?;
3676                    }
3677                    Sl::Gradient { x, y } => {
3678                        write!(self.out, ", ")?;
3679                        self.write_expr(module, x, func_ctx)?;
3680                        write!(self.out, ", ")?;
3681                        self.write_expr(module, y, func_ctx)?;
3682                    }
3683                }
3684
3685                if let Some(offset) = offset {
3686                    write!(self.out, ", ")?;
3687                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3688                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3689                    write!(self.out, ")")?;
3690                }
3691
3692                write!(self.out, ")")?;
3693            }
3694            Expression::ImageQuery { image, query } => {
3695                // use wrapped image query function
3696                if let TypeInner::Image {
3697                    dim,
3698                    arrayed,
3699                    class,
3700                } = *func_ctx.resolve_type(image, &module.types)
3701                {
3702                    let wrapped_image_query = WrappedImageQuery {
3703                        dim,
3704                        arrayed,
3705                        class,
3706                        query: query.into(),
3707                    };
3708
3709                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3710                    write!(self.out, "(")?;
3711                    // Image always first param
3712                    self.write_expr(module, image, func_ctx)?;
3713                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3714                        write!(self.out, ", ")?;
3715                        self.write_expr(module, level, func_ctx)?;
3716                    }
3717                    write!(self.out, ")")?;
3718                }
3719            }
3720            Expression::ImageLoad {
3721                image,
3722                coordinate,
3723                array_index,
3724                sample,
3725                level,
3726            } => self.write_image_load(
3727                &module,
3728                expr,
3729                func_ctx,
3730                image,
3731                coordinate,
3732                array_index,
3733                sample,
3734                level,
3735            )?,
3736            Expression::GlobalVariable(handle) => {
3737                let global_variable = &module.global_variables[handle];
3738                let ty = &module.types[global_variable.ty].inner;
3739
3740                // In the case of binding arrays of samplers, we need to not write anything
3741                // as we are in the wrong position to fully write the expression.
3742                //
3743                // The entire writing is done by AccessIndex.
3744                let is_binding_array_of_samplers = match *ty {
3745                    TypeInner::BindingArray { base, .. } => {
3746                        let base_ty = &module.types[base].inner;
3747                        matches!(*base_ty, TypeInner::Sampler { .. })
3748                    }
3749                    _ => false,
3750                };
3751
3752                let is_storage_space =
3753                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3754
3755                // Our external texture global variable has been expanded into multiple
3756                // global variables, one for each plane and the parameters buffer.
3757                // External textures can only ever be used as arguments to a function
3758                // call, and we know that an external texture argument to any function
3759                // will have been expanded to separate consecutive arguments for each
3760                // plane and the parameters buffer. Therefore we can simply emit each of
3761                // the expanded global variables in a consecutive comma-separated list.
3762                if let TypeInner::Image {
3763                    class: crate::ImageClass::External,
3764                    ..
3765                } = *ty
3766                {
3767                    let plane_names = [0, 1, 2].map(|i| {
3768                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3769                            handle,
3770                            ExternalTextureNameKey::Plane(i),
3771                        )]
3772                    });
3773                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3774                        handle,
3775                        ExternalTextureNameKey::Params,
3776                    )];
3777                    write!(
3778                        self.out,
3779                        "{}, {}, {}, {}",
3780                        plane_names[0], plane_names[1], plane_names[2], params_name
3781                    )?;
3782                } else if !is_binding_array_of_samplers && !is_storage_space {
3783                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3784                    write!(self.out, "{name}")?;
3785                }
3786            }
3787            Expression::LocalVariable(handle) => {
3788                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3789            }
3790            Expression::Load { pointer } => {
3791                match func_ctx
3792                    .resolve_type(pointer, &module.types)
3793                    .pointer_space()
3794                {
3795                    Some(crate::AddressSpace::Storage { .. }) => {
3796                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3797                        let result_ty = func_ctx.info[expr].ty.clone();
3798                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3799                    }
3800                    _ => {
3801                        let mut close_paren = false;
3802
3803                        // We cast the value loaded to a native HLSL floatCx2
3804                        // in cases where it is of type:
3805                        //  - __matCx2 or
3806                        //  - a (possibly nested) array of __matCx2's
3807                        if let Some(MatrixType {
3808                            rows: crate::VectorSize::Bi,
3809                            width: 4,
3810                            ..
3811                        }) = get_inner_matrix_of_struct_array_member(
3812                            module, pointer, func_ctx, false,
3813                        )
3814                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3815                        {
3816                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3817                            let ptr_tr = resolved.pointer_base_type();
3818                            if let Some(ptr_ty) =
3819                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3820                            {
3821                                resolved = ptr_ty;
3822                            }
3823
3824                            write!(self.out, "((")?;
3825                            if let TypeInner::Array { base, size, .. } = *resolved {
3826                                self.write_type(module, base)?;
3827                                self.write_array_size(module, base, size)?;
3828                            } else {
3829                                self.write_value_type(module, resolved)?;
3830                            }
3831                            write!(self.out, ")")?;
3832                            close_paren = true;
3833                        }
3834
3835                        self.write_expr(module, pointer, func_ctx)?;
3836
3837                        if close_paren {
3838                            write!(self.out, ")")?;
3839                        }
3840                    }
3841                }
3842            }
3843            Expression::Unary { op, expr } => {
3844                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3845                let op_str = match op {
3846                    crate::UnaryOperator::Negate => {
3847                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3848                            Some(Scalar::I32) => NEG_FUNCTION,
3849                            _ => "-",
3850                        }
3851                    }
3852                    crate::UnaryOperator::LogicalNot => "!",
3853                    crate::UnaryOperator::BitwiseNot => "~",
3854                };
3855                write!(self.out, "{op_str}(")?;
3856                self.write_expr(module, expr, func_ctx)?;
3857                write!(self.out, ")")?;
3858            }
3859            Expression::As {
3860                expr,
3861                kind,
3862                convert,
3863            } => {
3864                let inner = func_ctx.resolve_type(expr, &module.types);
3865                if inner.scalar_kind() == Some(ScalarKind::Float)
3866                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3867                    && convert.is_some()
3868                    && matches!(convert, Some(4) | Some(8))
3869                {
3870                    // Use helper functions for float to int casts in order to
3871                    // avoid undefined behaviour when value is out of range for
3872                    // the target type.
3873                    let fun_name = match (kind, convert) {
3874                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3875                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3876                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3877                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3878                        _ => unreachable!(),
3879                    };
3880                    write!(self.out, "{fun_name}(")?;
3881                    self.write_expr(module, expr, func_ctx)?;
3882                    write!(self.out, ")")?;
3883                } else {
3884                    let close_paren = match convert {
3885                        Some(dst_width) => {
3886                            let scalar = Scalar {
3887                                kind,
3888                                width: dst_width,
3889                            };
3890                            match *inner {
3891                                TypeInner::Vector { size, .. } => {
3892                                    write!(
3893                                        self.out,
3894                                        "{}{}(",
3895                                        scalar.to_hlsl_str()?,
3896                                        common::vector_size_str(size)
3897                                    )?;
3898                                }
3899                                TypeInner::Scalar(_) => {
3900                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3901                                }
3902                                TypeInner::Matrix { columns, rows, .. } => {
3903                                    write!(
3904                                        self.out,
3905                                        "{}{}x{}(",
3906                                        scalar.to_hlsl_str()?,
3907                                        common::vector_size_str(columns),
3908                                        common::vector_size_str(rows)
3909                                    )?;
3910                                }
3911                                _ => {
3912                                    return Err(Error::Unimplemented(format!(
3913                                        "write_expr expression::as {inner:?}"
3914                                    )));
3915                                }
3916                            };
3917                            true
3918                        }
3919                        None => {
3920                            if inner.scalar_width() == Some(8) {
3921                                false
3922                            } else if inner.scalar_width() == Some(2) {
3923                                // HLSL's asint()/asuint() only work on 32-bit types.
3924                                // For 16-bit bitcasts, use type constructor instead.
3925                                let dst_scalar = Scalar { kind, width: 2 };
3926                                match *inner {
3927                                    TypeInner::Vector { size, .. } => {
3928                                        write!(
3929                                            self.out,
3930                                            "{}{}(",
3931                                            dst_scalar.to_hlsl_str()?,
3932                                            common::vector_size_str(size)
3933                                        )?;
3934                                    }
3935                                    _ => {
3936                                        write!(self.out, "{}(", dst_scalar.to_hlsl_str()?)?;
3937                                    }
3938                                };
3939                                true
3940                            } else {
3941                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3942                                true
3943                            }
3944                        }
3945                    };
3946                    self.write_expr(module, expr, func_ctx)?;
3947                    if close_paren {
3948                        write!(self.out, ")")?;
3949                    }
3950                }
3951            }
3952            Expression::Math {
3953                fun,
3954                arg,
3955                arg1,
3956                arg2,
3957                arg3,
3958            } => {
3959                use crate::MathFunction as Mf;
3960
3961                enum Function { // How a `MathFunction` gets written; chosen by the `match fun` mapping below.
3962                    Asincosh { is_sin: bool }, // `Asinh` (`is_sin: true`) / `Acosh` (`false`): `log(x + sqrt(x * x +/- 1.0))`.
3963                    Atanh, // Written as `0.5 * log((1.0 + x) / (1.0 - x))`.
3964                    Pack2x16float, // `f32tof16` of components `[0]` and `[1]`, the second shifted left by 16.
3965                    Pack2x16snorm, // Clamp to [-1.0, 1.0], scale by 32767.0, round, mask to 16 bits per lane.
3966                    Pack2x16unorm, // Clamp to [0.0, 1.0], scale by 65535.0, round, second lane shifted left by 16.
3967                    Pack4x8snorm, // Clamp to [-1.0, 1.0], scale by 127.0, round, mask to 8 bits per lane.
3968                    Pack4x8unorm,
3969                    Pack4xI8,
3970                    Pack4xU8,
3971                    Pack4xI8Clamp,
3972                    Pack4xU8Clamp,
3973                    Unpack2x16float,
3974                    Unpack2x16snorm,
3975                    Unpack2x16unorm,
3976                    Unpack4x8snorm,
3977                    Unpack4x8unorm,
3978                    Unpack4xI8,
3979                    Unpack4xU8,
3980                    Dot4I8Packed,
3981                    Dot4U8Packed,
3982                    QuantizeToF16,
3983                    Regular(&'static str), // Carries an HLSL intrinsic or naga helper name, e.g. "min" or `ABS_FUNCTION`.
3984                    MissingIntOverload(&'static str), // Carries an HLSL intrinsic name, e.g. "countbits" or "reversebits".
3985                    MissingIntReturnType(&'static str), // Carries an HLSL intrinsic name, e.g. "firstbitlow" or "firstbithigh".
3986                    CountTrailingZeros,
3987                    CountLeadingZeros,
3988                }
3989
3990                let fun = match fun {
3991                    // comparison
3992                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3993                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3994                        _ => Function::Regular("abs"),
3995                    },
3996                    Mf::Min => Function::Regular("min"),
3997                    Mf::Max => Function::Regular("max"),
3998                    Mf::Clamp => Function::Regular("clamp"),
3999                    Mf::Saturate => Function::Regular("saturate"),
4000                    // trigonometry
4001                    Mf::Cos => Function::Regular("cos"),
4002                    Mf::Cosh => Function::Regular("cosh"),
4003                    Mf::Sin => Function::Regular("sin"),
4004                    Mf::Sinh => Function::Regular("sinh"),
4005                    Mf::Tan => Function::Regular("tan"),
4006                    Mf::Tanh => Function::Regular("tanh"),
4007                    Mf::Acos => Function::Regular("acos"),
4008                    Mf::Asin => Function::Regular("asin"),
4009                    Mf::Atan => Function::Regular("atan"),
4010                    Mf::Atan2 => Function::Regular("atan2"),
4011                    Mf::Asinh => Function::Asincosh { is_sin: true },
4012                    Mf::Acosh => Function::Asincosh { is_sin: false },
4013                    Mf::Atanh => Function::Atanh,
4014                    Mf::Radians => Function::Regular("radians"),
4015                    Mf::Degrees => Function::Regular("degrees"),
4016                    // decomposition
4017                    Mf::Ceil => Function::Regular("ceil"),
4018                    Mf::Floor => Function::Regular("floor"),
4019                    Mf::Round => Function::Regular("round"),
4020                    Mf::Fract => Function::Regular("frac"),
4021                    Mf::Trunc => Function::Regular("trunc"),
4022                    Mf::Modf => Function::Regular(MODF_FUNCTION),
4023                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
4024                    Mf::Ldexp => Function::Regular("ldexp"),
4025                    // exponent
4026                    Mf::Exp => Function::Regular("exp"),
4027                    Mf::Exp2 => Function::Regular("exp2"),
4028                    Mf::Log => Function::Regular("log"),
4029                    Mf::Log2 => Function::Regular("log2"),
4030                    Mf::Pow => Function::Regular("pow"),
4031                    // geometry
4032                    Mf::Dot => Function::Regular("dot"),
4033                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
4034                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
4035                    //Mf::Outer => ,
4036                    Mf::Cross => Function::Regular("cross"),
4037                    Mf::Distance => Function::Regular("distance"),
4038                    Mf::Length => Function::Regular("length"),
4039                    Mf::Normalize => Function::Regular("normalize"),
4040                    Mf::FaceForward => Function::Regular("faceforward"),
4041                    Mf::Reflect => Function::Regular("reflect"),
4042                    Mf::Refract => Function::Regular("refract"),
4043                    // computational
4044                    Mf::Sign => Function::Regular("sign"),
4045                    Mf::Fma => Function::Regular("mad"),
4046                    Mf::Mix => Function::Regular("lerp"),
4047                    Mf::Step => Function::Regular("step"),
4048                    Mf::SmoothStep => Function::Regular("smoothstep"),
4049                    Mf::Sqrt => Function::Regular("sqrt"),
4050                    Mf::InverseSqrt => Function::Regular("rsqrt"),
4051                    //Mf::Inverse =>,
4052                    Mf::Transpose => Function::Regular("transpose"),
4053                    Mf::Determinant => Function::Regular("determinant"),
4054                    Mf::QuantizeToF16 => Function::QuantizeToF16,
4055                    // bits
4056                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
4057                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
4058                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
4059                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
4060                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
4061                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
4062                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
4063                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
4064                    // Data Packing
4065                    Mf::Pack2x16float => Function::Pack2x16float,
4066                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
4067                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
4068                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
4069                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
4070                    Mf::Pack4xI8 => Function::Pack4xI8,
4071                    Mf::Pack4xU8 => Function::Pack4xU8,
4072                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4073                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4074                    // Data Unpacking
4075                    Mf::Unpack2x16float => Function::Unpack2x16float,
4076                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4077                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4078                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4079                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4080                    Mf::Unpack4xI8 => Function::Unpack4xI8,
4081                    Mf::Unpack4xU8 => Function::Unpack4xU8,
4082                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4083                };
4084
4085                match fun {
4086                    Function::Asincosh { is_sin } => {
4087                        write!(self.out, "log(")?;
4088                        self.write_expr(module, arg, func_ctx)?;
4089                        write!(self.out, " + sqrt(")?;
4090                        self.write_expr(module, arg, func_ctx)?;
4091                        write!(self.out, " * ")?;
4092                        self.write_expr(module, arg, func_ctx)?;
4093                        match is_sin {
4094                            true => write!(self.out, " + 1.0))")?,
4095                            false => write!(self.out, " - 1.0))")?,
4096                        }
4097                    }
4098                    Function::Atanh => {
4099                        write!(self.out, "0.5 * log((1.0 + ")?;
4100                        self.write_expr(module, arg, func_ctx)?;
4101                        write!(self.out, ") / (1.0 - ")?;
4102                        self.write_expr(module, arg, func_ctx)?;
4103                        write!(self.out, "))")?;
4104                    }
4105                    Function::Pack2x16float => {
4106                        write!(self.out, "(f32tof16(")?;
4107                        self.write_expr(module, arg, func_ctx)?;
4108                        write!(self.out, "[0]) | f32tof16(")?;
4109                        self.write_expr(module, arg, func_ctx)?;
4110                        write!(self.out, "[1]) << 16)")?;
4111                    }
4112                    Function::Pack2x16snorm => {
4113                        let scale = 32767;
4114
4115                        write!(self.out, "uint((int(round(clamp(")?;
4116                        self.write_expr(module, arg, func_ctx)?;
4117                        write!(
4118                            self.out,
4119                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4120                        )?;
4121                        self.write_expr(module, arg, func_ctx)?;
4122                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4123                    }
4124                    Function::Pack2x16unorm => {
4125                        let scale = 65535;
4126
4127                        write!(self.out, "(uint(round(clamp(")?;
4128                        self.write_expr(module, arg, func_ctx)?;
4129                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4130                        self.write_expr(module, arg, func_ctx)?;
4131                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4132                    }
4133                    Function::Pack4x8snorm => {
4134                        let scale = 127;
4135
4136                        write!(self.out, "uint((int(round(clamp(")?;
4137                        self.write_expr(module, arg, func_ctx)?;
4138                        write!(
4139                            self.out,
4140                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4141                        )?;
4142                        self.write_expr(module, arg, func_ctx)?;
4143                        write!(
4144                            self.out,
4145                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4146                        )?;
4147                        self.write_expr(module, arg, func_ctx)?;
4148                        write!(
4149                            self.out,
4150                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4151                        )?;
4152                        self.write_expr(module, arg, func_ctx)?;
4153                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4154                    }
4155                    Function::Pack4x8unorm => {
4156                        let scale = 255;
4157
4158                        write!(self.out, "(uint(round(clamp(")?;
4159                        self.write_expr(module, arg, func_ctx)?;
4160                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4161                        self.write_expr(module, arg, func_ctx)?;
4162                        write!(
4163                            self.out,
4164                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4165                        )?;
4166                        self.write_expr(module, arg, func_ctx)?;
4167                        write!(
4168                            self.out,
4169                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4170                        )?;
4171                        self.write_expr(module, arg, func_ctx)?;
4172                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4173                    }
4174                    fun @ (Function::Pack4xI8
4175                    | Function::Pack4xU8
4176                    | Function::Pack4xI8Clamp
4177                    | Function::Pack4xU8Clamp) => {
4178                        let was_signed =
4179                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4180                        let clamp_bounds = match fun {
4181                            Function::Pack4xI8Clamp => Some(("-128", "127")),
4182                            Function::Pack4xU8Clamp => Some(("0", "255")),
4183                            _ => None,
4184                        };
4185                        if was_signed {
4186                            write!(self.out, "uint(")?;
4187                        }
4188                        let write_arg = |this: &mut Self| -> BackendResult {
4189                            if let Some((min, max)) = clamp_bounds {
4190                                write!(this.out, "clamp(")?;
4191                                this.write_expr(module, arg, func_ctx)?;
4192                                write!(this.out, ", {min}, {max})")?;
4193                            } else {
4194                                this.write_expr(module, arg, func_ctx)?;
4195                            }
4196                            Ok(())
4197                        };
4198                        write!(self.out, "(")?;
4199                        write_arg(self)?;
4200                        write!(self.out, "[0] & 0xFF) | ((")?;
4201                        write_arg(self)?;
4202                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4203                        write_arg(self)?;
4204                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4205                        write_arg(self)?;
4206                        write!(self.out, "[3] & 0xFF) << 24)")?;
4207                        if was_signed {
4208                            write!(self.out, ")")?;
4209                        }
4210                    }
4211
4212                    Function::Unpack2x16float => {
4213                        write!(self.out, "float2(f16tof32(")?;
4214                        self.write_expr(module, arg, func_ctx)?;
4215                        write!(self.out, "), f16tof32((")?;
4216                        self.write_expr(module, arg, func_ctx)?;
4217                        write!(self.out, ") >> 16))")?;
4218                    }
4219                    Function::Unpack2x16snorm => {
4220                        let scale = 32767;
4221
4222                        write!(self.out, "(float2(int2(")?;
4223                        self.write_expr(module, arg, func_ctx)?;
4224                        write!(self.out, " << 16, ")?;
4225                        self.write_expr(module, arg, func_ctx)?;
4226                        write!(self.out, ") >> 16) / {scale}.0)")?;
4227                    }
4228                    Function::Unpack2x16unorm => {
4229                        let scale = 65535;
4230
4231                        write!(self.out, "(float2(")?;
4232                        self.write_expr(module, arg, func_ctx)?;
4233                        write!(self.out, " & 0xFFFF, ")?;
4234                        self.write_expr(module, arg, func_ctx)?;
4235                        write!(self.out, " >> 16) / {scale}.0)")?;
4236                    }
4237                    Function::Unpack4x8snorm => {
4238                        let scale = 127;
4239
4240                        write!(self.out, "(float4(int4(")?;
4241                        self.write_expr(module, arg, func_ctx)?;
4242                        write!(self.out, " << 24, ")?;
4243                        self.write_expr(module, arg, func_ctx)?;
4244                        write!(self.out, " << 16, ")?;
4245                        self.write_expr(module, arg, func_ctx)?;
4246                        write!(self.out, " << 8, ")?;
4247                        self.write_expr(module, arg, func_ctx)?;
4248                        write!(self.out, ") >> 24) / {scale}.0)")?;
4249                    }
4250                    Function::Unpack4x8unorm => {
4251                        let scale = 255;
4252
4253                        write!(self.out, "(float4(")?;
4254                        self.write_expr(module, arg, func_ctx)?;
4255                        write!(self.out, " & 0xFF, ")?;
4256                        self.write_expr(module, arg, func_ctx)?;
4257                        write!(self.out, " >> 8 & 0xFF, ")?;
4258                        self.write_expr(module, arg, func_ctx)?;
4259                        write!(self.out, " >> 16 & 0xFF, ")?;
4260                        self.write_expr(module, arg, func_ctx)?;
4261                        write!(self.out, " >> 24) / {scale}.0)")?;
4262                    }
4263                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4264                        write!(self.out, "(")?;
4265                        if matches!(fun, Function::Unpack4xU8) {
4266                            write!(self.out, "u")?;
4267                        }
4268                        write!(self.out, "int4(")?;
4269                        self.write_expr(module, arg, func_ctx)?;
4270                        write!(self.out, ", ")?;
4271                        self.write_expr(module, arg, func_ctx)?;
4272                        write!(self.out, " >> 8, ")?;
4273                        self.write_expr(module, arg, func_ctx)?;
4274                        write!(self.out, " >> 16, ")?;
4275                        self.write_expr(module, arg, func_ctx)?;
4276                        write!(self.out, " >> 24) << 24 >> 24)")?;
4277                    }
4278                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4279                        let arg1 = arg1.unwrap();
4280
4281                        if self.options.shader_model >= ShaderModel::V6_4 {
4282                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
4283                            let function_name = match fun {
4284                                Function::Dot4I8Packed => "dot4add_i8packed",
4285                                Function::Dot4U8Packed => "dot4add_u8packed",
4286                                _ => unreachable!(),
4287                            };
4288                            write!(self.out, "{function_name}(")?;
4289                            self.write_expr(module, arg, func_ctx)?;
4290                            write!(self.out, ", ")?;
4291                            self.write_expr(module, arg1, func_ctx)?;
4292                            write!(self.out, ", 0)")?;
4293                        } else {
4294                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
4295                            write!(self.out, "dot(")?;
4296
4297                            if matches!(fun, Function::Dot4U8Packed) {
4298                                write!(self.out, "u")?;
4299                            }
4300                            write!(self.out, "int4(")?;
4301                            self.write_expr(module, arg, func_ctx)?;
4302                            write!(self.out, ", ")?;
4303                            self.write_expr(module, arg, func_ctx)?;
4304                            write!(self.out, " >> 8, ")?;
4305                            self.write_expr(module, arg, func_ctx)?;
4306                            write!(self.out, " >> 16, ")?;
4307                            self.write_expr(module, arg, func_ctx)?;
4308                            write!(self.out, " >> 24) << 24 >> 24, ")?;
4309
4310                            if matches!(fun, Function::Dot4U8Packed) {
4311                                write!(self.out, "u")?;
4312                            }
4313                            write!(self.out, "int4(")?;
4314                            self.write_expr(module, arg1, func_ctx)?;
4315                            write!(self.out, ", ")?;
4316                            self.write_expr(module, arg1, func_ctx)?;
4317                            write!(self.out, " >> 8, ")?;
4318                            self.write_expr(module, arg1, func_ctx)?;
4319                            write!(self.out, " >> 16, ")?;
4320                            self.write_expr(module, arg1, func_ctx)?;
4321                            write!(self.out, " >> 24) << 24 >> 24)")?;
4322                        }
4323                    }
4324                    Function::QuantizeToF16 => {
4325                        write!(self.out, "f16tof32(f32tof16(")?;
4326                        self.write_expr(module, arg, func_ctx)?;
4327                        write!(self.out, "))")?;
4328                    }
4329                    Function::Regular(fun_name) => {
4330                        write!(self.out, "{fun_name}(")?;
4331                        self.write_expr(module, arg, func_ctx)?;
4332                        if let Some(arg) = arg1 {
4333                            write!(self.out, ", ")?;
4334                            self.write_expr(module, arg, func_ctx)?;
4335                        }
4336                        if let Some(arg) = arg2 {
4337                            write!(self.out, ", ")?;
4338                            self.write_expr(module, arg, func_ctx)?;
4339                        }
4340                        if let Some(arg) = arg3 {
4341                            write!(self.out, ", ")?;
4342                            self.write_expr(module, arg, func_ctx)?;
4343                        }
4344                        write!(self.out, ")")?
4345                    }
4346                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4347                    // as non-32bit types are DXC only.
4348                    Function::MissingIntOverload(fun_name) => {
4349                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4350                        if let Some(Scalar::I32) = scalar_kind {
4351                            write!(self.out, "asint({fun_name}(asuint(")?;
4352                            self.write_expr(module, arg, func_ctx)?;
4353                            write!(self.out, ")))")?;
4354                        } else {
4355                            write!(self.out, "{fun_name}(")?;
4356                            self.write_expr(module, arg, func_ctx)?;
4357                            write!(self.out, ")")?;
4358                        }
4359                    }
4360                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4361                    // as non-32bit types are DXC only.
4362                    Function::MissingIntReturnType(fun_name) => {
4363                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4364                        if let Some(Scalar::I32) = scalar_kind {
4365                            write!(self.out, "asint({fun_name}(")?;
4366                            self.write_expr(module, arg, func_ctx)?;
4367                            write!(self.out, "))")?;
4368                        } else {
4369                            write!(self.out, "{fun_name}(")?;
4370                            self.write_expr(module, arg, func_ctx)?;
4371                            write!(self.out, ")")?;
4372                        }
4373                    }
4374                    Function::CountTrailingZeros => {
4375                        match *func_ctx.resolve_type(arg, &module.types) {
4376                            TypeInner::Vector { size, scalar } => {
4377                                let s = match size {
4378                                    crate::VectorSize::Bi => ".xx",
4379                                    crate::VectorSize::Tri => ".xxx",
4380                                    crate::VectorSize::Quad => ".xxxx",
4381                                };
4382
4383                                let scalar_width_bits = scalar.width * 8;
4384
4385                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4386                                    write!(
4387                                        self.out,
4388                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4389                                    )?;
4390                                    self.write_expr(module, arg, func_ctx)?;
4391                                    write!(self.out, "))")?;
4392                                } else {
4393                                    // This is only needed for the FXC path, on 32bit signed integers.
4394                                    write!(
4395                                        self.out,
4396                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4397                                    )?;
4398                                    self.write_expr(module, arg, func_ctx)?;
4399                                    write!(self.out, ")))")?;
4400                                }
4401                            }
4402                            TypeInner::Scalar(scalar) => {
4403                                let scalar_width_bits = scalar.width * 8;
4404
4405                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4406                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4407                                    self.write_expr(module, arg, func_ctx)?;
4408                                    write!(self.out, "))")?;
4409                                } else {
4410                                    // This is only needed for the FXC path, on 32bit signed integers.
4411                                    write!(
4412                                        self.out,
4413                                        "asint(min({scalar_width_bits}u, firstbitlow("
4414                                    )?;
4415                                    self.write_expr(module, arg, func_ctx)?;
4416                                    write!(self.out, ")))")?;
4417                                }
4418                            }
4419                            _ => unreachable!(),
4420                        }
4421
4422                        return Ok(());
4423                    }
4424                    Function::CountLeadingZeros => {
4425                        match *func_ctx.resolve_type(arg, &module.types) {
4426                            TypeInner::Vector { size, scalar } => {
4427                                let s = match size {
4428                                    crate::VectorSize::Bi => ".xx",
4429                                    crate::VectorSize::Tri => ".xxx",
4430                                    crate::VectorSize::Quad => ".xxxx",
4431                                };
4432
4433                                // scalar width - 1
4434                                let constant = scalar.width * 8 - 1;
4435
4436                                if scalar.kind == ScalarKind::Uint {
4437                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4438                                    self.write_expr(module, arg, func_ctx)?;
4439                                    write!(self.out, "))")?;
4440                                } else {
4441                                    let conversion_func = match scalar.width {
4442                                        4 => "asint",
4443                                        _ => "",
4444                                    };
4445                                    write!(self.out, "(")?;
4446                                    self.write_expr(module, arg, func_ctx)?;
4447                                    write!(
4448                                        self.out,
4449                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4450                                    )?;
4451                                    self.write_expr(module, arg, func_ctx)?;
4452                                    write!(self.out, ")))")?;
4453                                }
4454                            }
4455                            TypeInner::Scalar(scalar) => {
4456                                // scalar width - 1
4457                                let constant = scalar.width * 8 - 1;
4458
4459                                if let ScalarKind::Uint = scalar.kind {
4460                                    write!(self.out, "({constant}u - firstbithigh(")?;
4461                                    self.write_expr(module, arg, func_ctx)?;
4462                                    write!(self.out, "))")?;
4463                                } else {
4464                                    let conversion_func = match scalar.width {
4465                                        4 => "asint",
4466                                        _ => "",
4467                                    };
4468                                    write!(self.out, "(")?;
4469                                    self.write_expr(module, arg, func_ctx)?;
4470                                    write!(
4471                                        self.out,
4472                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4473                                    )?;
4474                                    self.write_expr(module, arg, func_ctx)?;
4475                                    write!(self.out, ")))")?;
4476                                }
4477                            }
4478                            _ => unreachable!(),
4479                        }
4480
4481                        return Ok(());
4482                    }
4483                }
4484            }
4485            Expression::Swizzle {
4486                size,
4487                vector,
4488                pattern,
4489            } => {
4490                self.write_expr(module, vector, func_ctx)?;
4491                write!(self.out, ".")?;
4492                for &sc in pattern[..size as usize].iter() {
4493                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4494                }
4495            }
4496            Expression::ArrayLength(expr) => {
4497                let var_handle = match func_ctx.expressions[expr] {
4498                    Expression::AccessIndex { base, index: _ } => {
4499                        match func_ctx.expressions[base] {
4500                            Expression::GlobalVariable(handle) => handle,
4501                            _ => unreachable!(),
4502                        }
4503                    }
4504                    Expression::GlobalVariable(handle) => handle,
4505                    _ => unreachable!(),
4506                };
4507
4508                let var = &module.global_variables[var_handle];
4509                let (offset, stride) = match module.types[var.ty].inner {
4510                    TypeInner::Array { stride, .. } => (0, stride),
4511                    TypeInner::Struct { ref members, .. } => {
4512                        let last = members.last().unwrap();
4513                        let stride = match module.types[last.ty].inner {
4514                            TypeInner::Array { stride, .. } => stride,
4515                            _ => unreachable!(),
4516                        };
4517                        (last.offset, stride)
4518                    }
4519                    _ => unreachable!(),
4520                };
4521
4522                let storage_access = match var.space {
4523                    crate::AddressSpace::Storage { access } => access,
4524                    _ => crate::StorageAccess::default(),
4525                };
4526                let wrapped_array_length = WrappedArrayLength {
4527                    writable: storage_access.contains(crate::StorageAccess::STORE),
4528                };
4529
4530                write!(self.out, "((")?;
4531                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4532                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4533                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4534            }
4535            Expression::Derivative { axis, ctrl, expr } => {
4536                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4537                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4538                    let tail = match ctrl {
4539                        Ctrl::Coarse => "coarse",
4540                        Ctrl::Fine => "fine",
4541                        Ctrl::None => unreachable!(),
4542                    };
4543                    write!(self.out, "abs(ddx_{tail}(")?;
4544                    self.write_expr(module, expr, func_ctx)?;
4545                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4546                    self.write_expr(module, expr, func_ctx)?;
4547                    write!(self.out, "))")?
4548                } else {
4549                    let fun_str = match (axis, ctrl) {
4550                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4551                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4552                        (Axis::X, Ctrl::None) => "ddx",
4553                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4554                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4555                        (Axis::Y, Ctrl::None) => "ddy",
4556                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4557                        (Axis::Width, Ctrl::None) => "fwidth",
4558                    };
4559                    write!(self.out, "{fun_str}(")?;
4560                    self.write_expr(module, expr, func_ctx)?;
4561                    write!(self.out, ")")?
4562                }
4563            }
4564            Expression::Relational { fun, argument } => {
4565                use crate::RelationalFunction as Rf;
4566
4567                let fun_str = match fun {
4568                    Rf::All => "all",
4569                    Rf::Any => "any",
4570                    Rf::IsNan => "isnan",
4571                    Rf::IsInf => "isinf",
4572                };
4573                write!(self.out, "{fun_str}(")?;
4574                self.write_expr(module, argument, func_ctx)?;
4575                write!(self.out, ")")?
4576            }
4577            Expression::Select {
4578                condition,
4579                accept,
4580                reject,
4581            } => {
4582                write!(self.out, "(")?;
4583                self.write_expr(module, condition, func_ctx)?;
4584                write!(self.out, " ? ")?;
4585                self.write_expr(module, accept, func_ctx)?;
4586                write!(self.out, " : ")?;
4587                self.write_expr(module, reject, func_ctx)?;
4588                write!(self.out, ")")?
4589            }
4590            Expression::RayQueryGetIntersection { query, committed } => {
4591                // For reasoning, see write_stmt
4592                let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4593                    unreachable!()
4594                };
4595
4596                let tracker_expr_name = format!(
4597                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4598                    self.names[&func_ctx.name_key(query_var)]
4599                );
4600
4601                if committed {
4602                    write!(self.out, "GetCommittedIntersection(")?;
4603                    self.write_expr(module, query, func_ctx)?;
4604                    write!(self.out, ", {tracker_expr_name})")?;
4605                } else {
4606                    write!(self.out, "GetCandidateIntersection(")?;
4607                    self.write_expr(module, query, func_ctx)?;
4608                    write!(self.out, ", {tracker_expr_name})")?;
4609                }
4610            }
4611            // Not supported yet
4612            Expression::RayQueryVertexPositions { .. }
4613            | Expression::CooperativeLoad { .. }
4614            | Expression::CooperativeMultiplyAdd { .. } => {
4615                unreachable!()
4616            }
4617            // Nothing to do here, since call expression already cached
4618            Expression::CallResult(_)
4619            | Expression::AtomicResult { .. }
4620            | Expression::WorkGroupUniformLoadResult { .. }
4621            | Expression::RayQueryProceedResult
4622            | Expression::SubgroupBallotResult
4623            | Expression::SubgroupOperationResult { .. } => {}
4624        }
4625
4626        if !closing_bracket.is_empty() {
4627            write!(self.out, "{closing_bracket}")?;
4628        }
4629        Ok(())
4630    }
4631
4632    #[allow(clippy::too_many_arguments)]
4633    fn write_image_load(
4634        &mut self,
4635        module: &&Module,
4636        expr: Handle<crate::Expression>,
4637        func_ctx: &back::FunctionCtx,
4638        image: Handle<crate::Expression>,
4639        coordinate: Handle<crate::Expression>,
4640        array_index: Option<Handle<crate::Expression>>,
4641        sample: Option<Handle<crate::Expression>>,
4642        level: Option<Handle<crate::Expression>>,
4643    ) -> Result<(), Error> {
4644        let mut wrapping_type = None;
4645        match *func_ctx.resolve_type(image, &module.types) {
4646            TypeInner::Image {
4647                class: crate::ImageClass::External,
4648                ..
4649            } => {
4650                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4651                self.write_expr(module, image, func_ctx)?;
4652                write!(self.out, ", ")?;
4653                self.write_expr(module, coordinate, func_ctx)?;
4654                write!(self.out, ")")?;
4655                return Ok(());
4656            }
4657            TypeInner::Image {
4658                class: crate::ImageClass::Storage { format, .. },
4659                ..
4660            } => {
4661                if format.single_component() {
4662                    wrapping_type = Some(Scalar::from(format));
4663                }
4664            }
4665            _ => {}
4666        }
4667        if let Some(scalar) = wrapping_type {
4668            write!(
4669                self.out,
4670                "{}{}(",
4671                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4672                scalar.to_hlsl_str()?
4673            )?;
4674        }
4675        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4676        self.write_expr(module, image, func_ctx)?;
4677        write!(self.out, ".Load(")?;
4678
4679        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4680
4681        if let Some(sample) = sample {
4682            write!(self.out, ", ")?;
4683            self.write_expr(module, sample, func_ctx)?;
4684        }
4685
4686        // close bracket for Load function
4687        write!(self.out, ")")?;
4688
4689        if wrapping_type.is_some() {
4690            write!(self.out, ")")?;
4691        }
4692
4693        // return x component if return type is scalar
4694        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4695            write!(self.out, ".x")?;
4696        }
4697        Ok(())
4698    }
4699
4700    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4701    /// can be generated later.
4702    fn sampler_binding_array_info_from_expression(
4703        &mut self,
4704        module: &Module,
4705        func_ctx: &back::FunctionCtx<'_>,
4706        base: Handle<crate::Expression>,
4707        resolved: &TypeInner,
4708    ) -> Option<BindingArraySamplerInfo> {
4709        if let TypeInner::BindingArray {
4710            base: base_ty_handle,
4711            ..
4712        } = *resolved
4713        {
4714            let base_ty = &module.types[base_ty_handle].inner;
4715            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4716                let base = &func_ctx.expressions[base];
4717
4718                if let crate::Expression::GlobalVariable(handle) = *base {
4719                    let variable = &module.global_variables[handle];
4720
4721                    let sampler_heap_name = match comparison {
4722                        true => COMPARISON_SAMPLER_HEAP_VAR,
4723                        false => SAMPLER_HEAP_VAR,
4724                    };
4725
4726                    return Some(BindingArraySamplerInfo {
4727                        sampler_heap_name,
4728                        sampler_index_buffer_name: self
4729                            .wrapped
4730                            .sampler_index_buffers
4731                            .get(&super::SamplerIndexBufferKey {
4732                                group: variable.binding.unwrap().group,
4733                            })
4734                            .unwrap()
4735                            .clone(),
4736                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4737                            .clone(),
4738                    });
4739                }
4740            }
4741        }
4742
4743        None
4744    }
4745
4746    fn write_named_expr(
4747        &mut self,
4748        module: &Module,
4749        handle: Handle<crate::Expression>,
4750        name: String,
4751        // The expression which is being named.
4752        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4753        expr: Handle<crate::Expression>,
4754        func_ctx: &back::FunctionCtx,
4755    ) -> BackendResult {
4756        if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4757            let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4758            if ty_inner.is_atomic_pointer(&module.types) {
4759                let pointer_space = ty_inner.pointer_space().unwrap();
4760                self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4761                write!(self.out, " {name}; ")?;
4762                match pointer_space {
4763                    crate::AddressSpace::WorkGroup => {
4764                        write!(self.out, "InterlockedOr(")?;
4765                        self.write_expr(module, pointer, func_ctx)?;
4766                    }
4767                    crate::AddressSpace::Storage { .. } => {
4768                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4769                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4770                        write!(self.out, "{var_name}.InterlockedOr(")?;
4771                        let chain = mem::take(&mut self.temp_access_chain);
4772                        self.write_storage_address(module, &chain, func_ctx)?;
4773                        self.temp_access_chain = chain;
4774                    }
4775                    _ => unreachable!(),
4776                }
4777                writeln!(self.out, ", 0, {name});")?;
4778                self.named_expressions.insert(expr, name);
4779                return Ok(());
4780            }
4781        }
4782        match func_ctx.info[expr].ty {
4783            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4784                TypeInner::Struct { .. } => {
4785                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4786                    write!(self.out, "{ty_name}")?;
4787                }
4788                _ => {
4789                    self.write_type(module, ty_handle)?;
4790                }
4791            },
4792            proc::TypeResolution::Value(ref inner) => {
4793                self.write_value_type(module, inner)?;
4794            }
4795        }
4796
4797        let resolved = func_ctx.resolve_type(expr, &module.types);
4798
4799        write!(self.out, " {name}")?;
4800        // If rhs is a array type, we should write array size
4801        if let TypeInner::Array { base, size, .. } = *resolved {
4802            self.write_array_size(module, base, size)?;
4803        }
4804        write!(self.out, " = ")?;
4805        self.write_expr(module, handle, func_ctx)?;
4806        writeln!(self.out, ";")?;
4807        self.named_expressions.insert(expr, name);
4808
4809        Ok(())
4810    }
4811
4812    /// Helper function that write default zero initialization
4813    pub(super) fn write_default_init(
4814        &mut self,
4815        module: &Module,
4816        ty: Handle<crate::Type>,
4817    ) -> BackendResult {
4818        write!(self.out, "(")?;
4819        self.write_type(module, ty)?;
4820        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4821            self.write_array_size(module, base, size)?;
4822        }
4823        write!(self.out, ")0")?;
4824        Ok(())
4825    }
4826
4827    pub(super) fn write_control_barrier(
4828        &mut self,
4829        barrier: crate::Barrier,
4830        level: back::Level,
4831    ) -> BackendResult {
4832        if barrier.contains(crate::Barrier::STORAGE) {
4833            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4834        }
4835        if barrier.contains(crate::Barrier::WORK_GROUP) {
4836            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4837        }
4838        if barrier.contains(crate::Barrier::SUB_GROUP) {
4839            // Does not exist in DirectX
4840        }
4841        if barrier.contains(crate::Barrier::TEXTURE) {
4842            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4843        }
4844        Ok(())
4845    }
4846
4847    fn write_memory_barrier(
4848        &mut self,
4849        barrier: crate::Barrier,
4850        level: back::Level,
4851    ) -> BackendResult {
4852        if barrier.contains(crate::Barrier::STORAGE) {
4853            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4854        }
4855        if barrier.contains(crate::Barrier::WORK_GROUP) {
4856            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4857        }
4858        if barrier.contains(crate::Barrier::SUB_GROUP) {
4859            // Does not exist in DirectX
4860        }
4861        if barrier.contains(crate::Barrier::TEXTURE) {
4862            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4863        }
4864        Ok(())
4865    }
4866
4867    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4868    fn emit_hlsl_atomic_tail(
4869        &mut self,
4870        module: &Module,
4871        func_ctx: &back::FunctionCtx<'_>,
4872        fun: &crate::AtomicFunction,
4873        compare_expr: Option<Handle<crate::Expression>>,
4874        value: Handle<crate::Expression>,
4875        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4876    ) -> BackendResult {
4877        if let Some(cmp) = compare_expr {
4878            write!(self.out, ", ")?;
4879            self.write_expr(module, cmp, func_ctx)?;
4880        }
4881        write!(self.out, ", ")?;
4882        if let crate::AtomicFunction::Subtract = *fun {
4883            // we just wrote `InterlockedAdd`, so negate the argument
4884            write!(self.out, "-")?;
4885        }
4886        self.write_expr(module, value, func_ctx)?;
4887        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4888            write!(self.out, ", ")?;
4889            if compare_expr.is_some() {
4890                write!(self.out, "{res_name}.old_value")?;
4891            } else {
4892                write!(self.out, "{res_name}")?;
4893            }
4894        }
4895        writeln!(self.out, ");")?;
4896        Ok(())
4897    }
4898}
4899
/// Shape and scalar width of a matrix type, as extracted from a
/// [`TypeInner::Matrix`] by the helper functions in this module.
pub(super) struct MatrixType {
    /// The `columns` field of the matrix type.
    pub(super) columns: crate::VectorSize,
    /// The `rows` field of the matrix type.
    pub(super) rows: crate::VectorSize,
    /// Width of the matrix's scalar type (copied from `scalar.width`).
    pub(super) width: crate::Bytes,
}
4905
4906pub(super) fn get_inner_matrix_data(
4907    module: &Module,
4908    handle: Handle<crate::Type>,
4909) -> Option<MatrixType> {
4910    match module.types[handle].inner {
4911        TypeInner::Matrix {
4912            columns,
4913            rows,
4914            scalar,
4915        } => Some(MatrixType {
4916            columns,
4917            rows,
4918            width: scalar.width,
4919        }),
4920        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4921        _ => None,
4922    }
4923}
4924
4925/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4926/// returns a tuple of the matrix, the column (vector) index (if present), and
4927/// the row (scalar) index (if present).
4928fn find_matrix_in_access_chain(
4929    module: &Module,
4930    base: Handle<crate::Expression>,
4931    func_ctx: &back::FunctionCtx<'_>,
4932) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4933    let mut current_base = base;
4934    let mut vector = None;
4935    let mut scalar = None;
4936    loop {
4937        let resolved_tr = func_ctx
4938            .resolve_type(current_base, &module.types)
4939            .pointer_base_type();
4940        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4941
4942        match *resolved {
4943            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4944            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4945            _ => return None,
4946        }
4947
4948        let index;
4949        (current_base, index) = match func_ctx.expressions[current_base] {
4950            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4951            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4952            _ => return None,
4953        };
4954
4955        match *resolved {
4956            TypeInner::Scalar(_) => scalar = Some(index),
4957            TypeInner::Vector { .. } => vector = Some(index),
4958            _ => unreachable!(),
4959        }
4960    }
4961}
4962
4963/// Returns the matrix data if the access chain starting at `base`:
4964/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4965/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4966/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4967pub(super) fn get_inner_matrix_of_struct_array_member(
4968    module: &Module,
4969    base: Handle<crate::Expression>,
4970    func_ctx: &back::FunctionCtx<'_>,
4971    direct: bool,
4972) -> Option<MatrixType> {
4973    let mut mat_data = None;
4974    let mut array_base = None;
4975
4976    let mut current_base = base;
4977    loop {
4978        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4979        if let TypeInner::Pointer { base, .. } = *resolved {
4980            resolved = &module.types[base].inner;
4981        };
4982
4983        match *resolved {
4984            TypeInner::Matrix {
4985                columns,
4986                rows,
4987                scalar,
4988            } => {
4989                mat_data = Some(MatrixType {
4990                    columns,
4991                    rows,
4992                    width: scalar.width,
4993                })
4994            }
4995            TypeInner::Array { base, .. } => {
4996                array_base = Some(base);
4997            }
4998            TypeInner::Struct { .. } => {
4999                if let Some(array_base) = array_base {
5000                    if direct {
5001                        return mat_data;
5002                    } else {
5003                        return get_inner_matrix_data(module, array_base);
5004                    }
5005                }
5006
5007                break;
5008            }
5009            _ => break,
5010        }
5011
5012        current_base = match func_ctx.expressions[current_base] {
5013            crate::Expression::Access { base, .. } => base,
5014            crate::Expression::AccessIndex { base, .. } => base,
5015            _ => break,
5016        };
5017    }
5018    None
5019}
5020
5021/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
5022/// immediate expression, rather than traversing an access chain.
5023fn get_global_uniform_matrix(
5024    module: &Module,
5025    base: Handle<crate::Expression>,
5026    func_ctx: &back::FunctionCtx<'_>,
5027) -> Option<MatrixType> {
5028    let base_tr = func_ctx
5029        .resolve_type(base, &module.types)
5030        .pointer_base_type();
5031    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
5032    match (&func_ctx.expressions[base], base_ty) {
5033        (
5034            &crate::Expression::GlobalVariable(handle),
5035            Some(&TypeInner::Matrix {
5036                columns,
5037                rows,
5038                scalar,
5039            }),
5040        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
5041            Some(MatrixType {
5042                columns,
5043                rows,
5044                width: scalar.width,
5045            })
5046        }
5047        _ => None,
5048    }
5049}
5050
5051/// Returns the matrix data if the access chain starting at `base`:
5052/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
5053/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
5054/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
5055fn get_inner_matrix_of_global_uniform(
5056    module: &Module,
5057    base: Handle<crate::Expression>,
5058    func_ctx: &back::FunctionCtx<'_>,
5059) -> Option<MatrixType> {
5060    let mut mat_data = None;
5061    let mut array_base = None;
5062
5063    let mut current_base = base;
5064    loop {
5065        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5066        if let TypeInner::Pointer { base, .. } = *resolved {
5067            resolved = &module.types[base].inner;
5068        };
5069
5070        match *resolved {
5071            TypeInner::Matrix {
5072                columns,
5073                rows,
5074                scalar,
5075            } => {
5076                mat_data = Some(MatrixType {
5077                    columns,
5078                    rows,
5079                    width: scalar.width,
5080                })
5081            }
5082            TypeInner::Array { base, .. } => {
5083                array_base = Some(base);
5084            }
5085            _ => break,
5086        }
5087
5088        current_base = match func_ctx.expressions[current_base] {
5089            crate::Expression::Access { base, .. } => base,
5090            crate::Expression::AccessIndex { base, .. } => base,
5091            crate::Expression::GlobalVariable(handle)
5092                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5093            {
5094                return mat_data.or_else(|| {
5095                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5096                })
5097            }
5098            _ => break,
5099        };
5100    }
5101    None
5102}