naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9    help,
10    help::{
11        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12        WrappedZeroValue,
13    },
14    storage::StoreValue,
15    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18    back::{self, get_entry_points, Baked},
19    common,
20    proc::{self, index, NameKey},
21    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Prefix of the semantic emitted for user-defined locations (`LOC<n>`)
/// whenever the binding is not a fragment-stage output.
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type declared for the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable that holds the special constants.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Member names of the special constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
const SPECIAL_OTHER: &str = "other";

// Identifiers shared with the rest of the backend.
// NOTE(review): their uses are outside this part of the file; presumably they
// name helper functions and heap variables emitted into the generated HLSL.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
47
/// An index that is either computed by an IR expression or known statically.
///
/// NOTE(review): the users of this type are outside this part of the file.
enum Index {
    /// Index given by an expression handle.
    Expression(Handle<crate::Expression>),
    /// Index given as a constant value.
    Static(u32),
}
52
/// One member of a generated (flattened) entry point interface struct.
struct EpStructMember {
    /// Name of the member in the generated struct.
    name: String,
    /// Type of the member.
    ty: Handle<crate::Type>,
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    binding: Option<crate::Binding>,
    /// Position of the member in the original (pre-sort) order; input
    /// members are re-sorted by this value after the struct is written.
    index: u32,
}
61
/// Information required for generating the wrapper structure that holds
/// all of an entry point's flattened arguments (or results).
struct EntryPointBinding {
    /// Name of the fake EP argument that contains the struct
    /// with all the flattened input data.
    arg_name: String,
    /// Generated structure name
    ty_name: String,
    /// Members of generated structure
    members: Vec<EpStructMember>,
}
73
/// The generated input and output interface structs of one entry point.
/// Either side is `None` when no wrapper struct was generated for it.
pub(super) struct EntryPointInterface {
    /// If `Some`, the input of an entry point is gathered in a special
    /// struct with members sorted by binding.
    /// The `EntryPointBinding::members` array is sorted by index,
    /// so that we can walk it in `write_ep_arguments_initialization`.
    input: Option<EntryPointBinding>,
    /// If `Some`, the output of an entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in the `Statement::Return` handler.
    output: Option<EntryPointBinding>,
}
85
/// Sort key for the members of an interface struct.
///
/// The derived ordering follows the variant declaration order, so
/// user-defined locations (ascending) sort before built-ins, and members
/// without a binding sort last.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    /// User-defined varying at the given location.
    Location(u32),
    /// Built-in value.
    BuiltIn(crate::BuiltIn),
    /// Member without any binding.
    Other,
}
92
93impl InterfaceKey {
94    const fn new(binding: Option<&crate::Binding>) -> Self {
95        match binding {
96            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
97            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
98            None => Self::Other,
99        }
100    }
101}
102
/// Direction of an entry point interface: arguments or results.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    /// The interface carries entry point inputs.
    Input,
    /// The interface carries entry point outputs.
    Output,
}
108
109const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
110    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
111        return false;
112    };
113    matches!(
114        builtin,
115        crate::BuiltIn::SubgroupSize
116            | crate::BuiltIn::SubgroupInvocationId
117            | crate::BuiltIn::NumSubgroups
118            | crate::BuiltIn::SubgroupId
119    )
120}
121
/// Information for how to generate a `binding_array<sampler>` access.
///
/// All three fields are names of variables in the generated code.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
131
132impl<'a, W: fmt::Write> super::Writer<'a, W> {
133    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
134        Self {
135            out,
136            names: crate::FastHashMap::default(),
137            namer: proc::Namer::default(),
138            options,
139            pipeline_options,
140            entry_point_io: crate::FastHashMap::default(),
141            named_expressions: crate::NamedExpressions::default(),
142            wrapped: super::Wrapped::default(),
143            written_committed_intersection: false,
144            written_candidate_intersection: false,
145            continue_ctx: back::continue_forward::ContinueCtx::default(),
146            temp_access_chain: Vec::new(),
147            need_bake_expressions: Default::default(),
148        }
149    }
150
151    fn reset(&mut self, module: &Module) {
152        self.names.clear();
153        self.namer.reset(
154            module,
155            &super::keywords::RESERVED_SET,
156            super::keywords::RESERVED_CASE_INSENSITIVE,
157            super::keywords::RESERVED_PREFIXES,
158            &mut self.names,
159        );
160        self.entry_point_io.clear();
161        self.named_expressions.clear();
162        self.wrapped.clear();
163        self.written_committed_intersection = false;
164        self.written_candidate_intersection = false;
165        self.continue_ctx.clear();
166        self.need_bake_expressions.clear();
167    }
168
169    /// Generates statements to be inserted immediately before and at the very
170    /// start of the body of each loop, to defeat infinite loop reasoning.
171    /// The 0th item of the returned tuple should be inserted immediately prior
172    /// to the loop and the 1st item should be inserted at the very start of
173    /// the loop body.
174    ///
175    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
176    fn gen_force_bounded_loop_statements(
177        &mut self,
178        level: back::Level,
179    ) -> Option<(String, String)> {
180        if !self.options.force_loop_bounding {
181            return None;
182        }
183
184        let loop_bound_name = self.namer.call("loop_bound");
185        let max = u32::MAX;
186        // Count down from u32::MAX rather than up from 0 to avoid hang on
187        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
188        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
189        let level = level.next();
190        let break_and_inc = format!(
191            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
192{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
193        );
194
195        Some((decl, break_and_inc))
196    }
197
198    /// Helper method used to find which expressions of a given function require baking
199    ///
200    /// # Notes
201    /// Clears `need_bake_expressions` set before adding to it
202    fn update_expressions_to_bake(
203        &mut self,
204        module: &Module,
205        func: &crate::Function,
206        info: &valid::FunctionInfo,
207    ) {
208        use crate::Expression;
209        self.need_bake_expressions.clear();
210        for (exp_handle, expr) in func.expressions.iter() {
211            let expr_info = &info[exp_handle];
212            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
213            if min_ref_count <= expr_info.ref_count {
214                self.need_bake_expressions.insert(exp_handle);
215            }
216
217            if let Expression::Math { fun, arg, arg1, .. } = *expr {
218                match fun {
219                    crate::MathFunction::Asinh
220                    | crate::MathFunction::Acosh
221                    | crate::MathFunction::Atanh
222                    | crate::MathFunction::Unpack2x16float
223                    | crate::MathFunction::Unpack2x16snorm
224                    | crate::MathFunction::Unpack2x16unorm
225                    | crate::MathFunction::Unpack4x8snorm
226                    | crate::MathFunction::Unpack4x8unorm
227                    | crate::MathFunction::Unpack4xI8
228                    | crate::MathFunction::Unpack4xU8
229                    | crate::MathFunction::Pack2x16float
230                    | crate::MathFunction::Pack2x16snorm
231                    | crate::MathFunction::Pack2x16unorm
232                    | crate::MathFunction::Pack4x8snorm
233                    | crate::MathFunction::Pack4x8unorm
234                    | crate::MathFunction::Pack4xI8
235                    | crate::MathFunction::Pack4xU8
236                    | crate::MathFunction::Pack4xI8Clamp
237                    | crate::MathFunction::Pack4xU8Clamp => {
238                        self.need_bake_expressions.insert(arg);
239                    }
240                    crate::MathFunction::CountLeadingZeros => {
241                        let inner = info[exp_handle].ty.inner_with(&module.types);
242                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
243                            self.need_bake_expressions.insert(arg);
244                        }
245                    }
246                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
247                        self.need_bake_expressions.insert(arg);
248                        self.need_bake_expressions.insert(arg1.unwrap());
249                    }
250                    _ => {}
251                }
252            }
253
254            if let Expression::Derivative { axis, ctrl, expr } = *expr {
255                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
256                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
257                    self.need_bake_expressions.insert(expr);
258                }
259            }
260
261            if let Expression::GlobalVariable(_) = *expr {
262                let inner = info[exp_handle].ty.inner_with(&module.types);
263
264                if let TypeInner::Sampler { .. } = *inner {
265                    self.need_bake_expressions.insert(exp_handle);
266                }
267            }
268        }
269        for statement in func.body.iter() {
270            match *statement {
271                crate::Statement::SubgroupCollectiveOperation {
272                    op: _,
273                    collective_op: crate::CollectiveOperation::InclusiveScan,
274                    argument,
275                    result: _,
276                } => {
277                    self.need_bake_expressions.insert(argument);
278                }
279                crate::Statement::Atomic {
280                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
281                    ..
282                } => {
283                    self.need_bake_expressions.insert(cmp);
284                }
285                _ => {}
286            }
287        }
288    }
289
    /// Writes `module` to `self.out` and reports the translated entry points.
    ///
    /// Output is produced in this order: the special constants buffer (if
    /// configured), the dynamic storage buffer offset buffers, the `matCx2`
    /// typedefs and functions, all structs, special and wrapped helper
    /// functions, named constants, globals, entry point interface structs,
    /// regular functions, and finally the entry points themselves.
    ///
    /// # Returns
    /// A [`super::ReflectionInfo`] whose `entry_point_names` holds, for each
    /// selected entry point, either its generated name or the binding
    /// resolution error that caused it to be skipped.
    ///
    /// # Errors
    /// Fails if writing to the output fails, if the entry point requested in
    /// the pipeline options cannot be found, or if any of the called writer
    /// helpers fails.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // The register space is only spelled out when it is non-zero.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One struct of `uint` members plus one constant buffer per group
        // listed in `dynamic_storage_buffer_offsets_targets`.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // NOTE(review): `unwrap` panics on a struct without members;
                // presumably validation rules those out - confirm.
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                // Find the first entry point (if any) whose result is this struct.
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(
                module,
                &ep.function,
                ep.stage,
                &ep_name,
                fragment_entry_point,
            )?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            if !self.options.fake_missing_bindings {
                // Look for a global that the function uses but whose
                // resource binding cannot be resolved.
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                            }
                            _ => false,
                        })
                {
                    log::info!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info)?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            if !self.options.fake_missing_bindings {
                // Unlike regular functions, an entry point with an
                // unresolvable binding is reported as an `Err` entry in the
                // reflection info instead of being skipped silently.
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                ep_error = Some(err);
                                break;
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            if ep.stage == ShaderStage::Compute {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    self.out,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info)?;

            // Separate entry points by a blank line, except after the
            // module's last one.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
522
523    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
524        match *binding {
525            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
526                write!(self.out, "precise ")?;
527            }
528            crate::Binding::Location {
529                interpolation,
530                sampling,
531                ..
532            } => {
533                if let Some(interpolation) = interpolation {
534                    if let Some(string) = interpolation.to_hlsl_str() {
535                        write!(self.out, "{string} ")?
536                    }
537                }
538
539                if let Some(sampling) = sampling {
540                    if let Some(string) = sampling.to_hlsl_str() {
541                        write!(self.out, "{string} ")?
542                    }
543                }
544            }
545            crate::Binding::BuiltIn(_) => {}
546        }
547
548        Ok(())
549    }
550
551    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
552    // if they are struct, so that the `stage` argument here could be omitted.
553    fn write_semantic(
554        &mut self,
555        binding: &Option<crate::Binding>,
556        stage: Option<(ShaderStage, Io)>,
557    ) -> BackendResult {
558        match *binding {
559            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
560                let builtin_str = builtin.to_hlsl_str()?;
561                write!(self.out, " : {builtin_str}")?;
562            }
563            Some(crate::Binding::Location {
564                blend_src: Some(1), ..
565            }) => {
566                write!(self.out, " : SV_Target1")?;
567            }
568            Some(crate::Binding::Location { location, .. }) => {
569                if stage == Some((ShaderStage::Fragment, Io::Output)) {
570                    write!(self.out, " : SV_Target{location}")?;
571                } else {
572                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
573                }
574            }
575            _ => {}
576        }
577
578        Ok(())
579    }
580
581    fn write_interface_struct(
582        &mut self,
583        module: &Module,
584        shader_stage: (ShaderStage, Io),
585        struct_name: String,
586        mut members: Vec<EpStructMember>,
587    ) -> Result<EntryPointBinding, Error> {
588        // Sort the members so that first come the user-defined varyings
589        // in ascending locations, and then built-ins. This allows VS and FS
590        // interfaces to match with regards to order.
591        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
592
593        write!(self.out, "struct {struct_name}")?;
594        writeln!(self.out, " {{")?;
595        for m in members.iter() {
596            // Sanity check that each IO member is a built-in or is assigned a
597            // location. Also see note about nesting in `write_ep_input_struct`.
598            debug_assert!(m.binding.is_some());
599
600            if is_subgroup_builtin_binding(&m.binding) {
601                continue;
602            }
603            write!(self.out, "{}", back::INDENT)?;
604            if let Some(ref binding) = m.binding {
605                self.write_modifier(binding)?;
606            }
607            self.write_type(module, m.ty)?;
608            write!(self.out, " {}", &m.name)?;
609            self.write_semantic(&m.binding, Some(shader_stage))?;
610            writeln!(self.out, ";")?;
611        }
612        if members.iter().any(|arg| {
613            matches!(
614                arg.binding,
615                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
616            )
617        }) {
618            writeln!(
619                self.out,
620                "{}uint __local_invocation_index : SV_GroupIndex;",
621                back::INDENT
622            )?;
623        }
624        writeln!(self.out, "}};")?;
625        writeln!(self.out)?;
626
627        // See ordering notes on EntryPointInterface fields
628        match shader_stage.1 {
629            Io::Input => {
630                // bring back the original order
631                members.sort_by_key(|m| m.index);
632            }
633            Io::Output => {
634                // keep it sorted by binding
635            }
636        }
637
638        Ok(EntryPointBinding {
639            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
640            ty_name: struct_name,
641            members,
642        })
643    }
644
645    /// Flatten all entry point arguments into a single struct.
646    /// This is needed since we need to re-order them: first placing user locations,
647    /// then built-ins.
648    fn write_ep_input_struct(
649        &mut self,
650        module: &Module,
651        func: &crate::Function,
652        stage: ShaderStage,
653        entry_point_name: &str,
654    ) -> Result<EntryPointBinding, Error> {
655        let struct_name = format!("{stage:?}Input_{entry_point_name}");
656
657        let mut fake_members = Vec::new();
658        for arg in func.arguments.iter() {
659            // NOTE: We don't need to handle nesting structs. All members must
660            // be either built-ins or assigned a location. I.E. `binding` is
661            // `Some`. This is checked in `VaryingContext::validate`. See:
662            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
663            match module.types[arg.ty].inner {
664                TypeInner::Struct { ref members, .. } => {
665                    for member in members.iter() {
666                        let name = self.namer.call_or(&member.name, "member");
667                        let index = fake_members.len() as u32;
668                        fake_members.push(EpStructMember {
669                            name,
670                            ty: member.ty,
671                            binding: member.binding.clone(),
672                            index,
673                        });
674                    }
675                }
676                _ => {
677                    let member_name = self.namer.call_or(&arg.name, "member");
678                    let index = fake_members.len() as u32;
679                    fake_members.push(EpStructMember {
680                        name: member_name,
681                        ty: arg.ty,
682                        binding: arg.binding.clone(),
683                        index,
684                    });
685                }
686            }
687        }
688
689        self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
690    }
691
692    /// Flatten all entry point results into a single struct.
693    /// This is needed since we need to re-order them: first placing user locations,
694    /// then built-ins.
695    fn write_ep_output_struct(
696        &mut self,
697        module: &Module,
698        result: &crate::FunctionResult,
699        stage: ShaderStage,
700        entry_point_name: &str,
701        frag_ep: Option<&FragmentEntryPoint<'_>>,
702    ) -> Result<EntryPointBinding, Error> {
703        let struct_name = format!("{stage:?}Output_{entry_point_name}");
704
705        let empty = [];
706        let members = match module.types[result.ty].inner {
707            TypeInner::Struct { ref members, .. } => members,
708            ref other => {
709                log::error!("Unexpected {other:?} output type without a binding");
710                &empty[..]
711            }
712        };
713
714        // Gather list of fragment input locations. We use this below to remove user-defined
715        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
716        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
717        // writer is supplied with information about the fragment entry point.
718        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
719            let mut fs_input_locs = Vec::new();
720            for arg in frag_ep.func.arguments.iter() {
721                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
722                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
723                    Some(crate::Binding::BuiltIn(_)) | None => {}
724                };
725
726                // NOTE: We don't need to handle struct nesting. See note in
727                // `write_ep_input_struct`.
728                match frag_ep.module.types[arg.ty].inner {
729                    TypeInner::Struct { ref members, .. } => {
730                        for member in members.iter() {
731                            push_if_location(&member.binding);
732                        }
733                    }
734                    _ => push_if_location(&arg.binding),
735                }
736            }
737            fs_input_locs.sort();
738            Some(fs_input_locs)
739        } else {
740            None
741        };
742
743        let mut fake_members = Vec::new();
744        for (index, member) in members.iter().enumerate() {
745            if let Some(ref fs_input_locs) = fs_input_locs {
746                match member.binding {
747                    Some(crate::Binding::Location { location, .. }) => {
748                        if fs_input_locs.binary_search(&location).is_err() {
749                            continue;
750                        }
751                    }
752                    Some(crate::Binding::BuiltIn(_)) | None => {}
753                }
754            }
755
756            let member_name = self.namer.call_or(&member.name, "member");
757            fake_members.push(EpStructMember {
758                name: member_name,
759                ty: member.ty,
760                binding: member.binding.clone(),
761                index: index as u32,
762            });
763        }
764
765        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
766    }
767
768    /// Writes special interface structures for an entry point. The special structures have
769    /// all the fields flattened into them and sorted by binding. They are needed to emulate
770    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
771    fn write_ep_interface(
772        &mut self,
773        module: &Module,
774        func: &crate::Function,
775        stage: ShaderStage,
776        ep_name: &str,
777        frag_ep: Option<&FragmentEntryPoint<'_>>,
778    ) -> Result<EntryPointInterface, Error> {
779        Ok(EntryPointInterface {
780            input: if !func.arguments.is_empty()
781                && (stage == ShaderStage::Fragment
782                    || func
783                        .arguments
784                        .iter()
785                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
786            {
787                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
788            } else {
789                None
790            },
791            output: match func.result {
792                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
793                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
794                }
795                _ => None,
796            },
797        })
798    }
799
800    fn write_ep_argument_initialization(
801        &mut self,
802        ep: &crate::EntryPoint,
803        ep_input: &EntryPointBinding,
804        fake_member: &EpStructMember,
805    ) -> BackendResult {
806        match fake_member.binding {
807            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
808                write!(self.out, "WaveGetLaneCount()")?
809            }
810            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
811                write!(self.out, "WaveGetLaneIndex()")?
812            }
813            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
814                self.out,
815                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
816                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
817            )?,
818            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
819                write!(
820                    self.out,
821                    "{}.__local_invocation_index / WaveGetLaneCount()",
822                    ep_input.arg_name
823                )?;
824            }
825            _ => {
826                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
827            }
828        }
829        Ok(())
830    }
831
832    /// Write an entry point preface that initializes the arguments as specified in IR.
833    fn write_ep_arguments_initialization(
834        &mut self,
835        module: &Module,
836        func: &crate::Function,
837        ep_index: u16,
838    ) -> BackendResult {
839        let ep = &module.entry_points[ep_index as usize];
840        let ep_input = match self
841            .entry_point_io
842            .get_mut(&(ep_index as usize))
843            .unwrap()
844            .input
845            .take()
846        {
847            Some(ep_input) => ep_input,
848            None => return Ok(()),
849        };
850        let mut fake_iter = ep_input.members.iter();
851        for (arg_index, arg) in func.arguments.iter().enumerate() {
852            write!(self.out, "{}", back::INDENT)?;
853            self.write_type(module, arg.ty)?;
854            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
855            write!(self.out, " {arg_name}")?;
856            match module.types[arg.ty].inner {
857                TypeInner::Array { base, size, .. } => {
858                    self.write_array_size(module, base, size)?;
859                    write!(self.out, " = ")?;
860                    self.write_ep_argument_initialization(
861                        ep,
862                        &ep_input,
863                        fake_iter.next().unwrap(),
864                    )?;
865                    writeln!(self.out, ";")?;
866                }
867                TypeInner::Struct { ref members, .. } => {
868                    write!(self.out, " = {{ ")?;
869                    for index in 0..members.len() {
870                        if index != 0 {
871                            write!(self.out, ", ")?;
872                        }
873                        self.write_ep_argument_initialization(
874                            ep,
875                            &ep_input,
876                            fake_iter.next().unwrap(),
877                        )?;
878                    }
879                    writeln!(self.out, " }};")?;
880                }
881                _ => {
882                    write!(self.out, " = ")?;
883                    self.write_ep_argument_initialization(
884                        ep,
885                        &ep_input,
886                        fake_iter.next().unwrap(),
887                    )?;
888                    writeln!(self.out, ";")?;
889                }
890            }
891        }
892        assert!(fake_iter.next().is_none());
893        Ok(())
894    }
895
896    /// Helper method used to write global variables
897    /// # Notes
898    /// Always adds a newline
899    fn write_global(
900        &mut self,
901        module: &Module,
902        handle: Handle<crate::GlobalVariable>,
903    ) -> BackendResult {
904        let global = &module.global_variables[handle];
905        let inner = &module.types[global.ty].inner;
906
907        if let Some(ref binding) = global.binding {
908            if let Err(err) = self.options.resolve_resource_binding(binding) {
909                log::info!(
910                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
911                    handle,
912                    global.name,
913                    err,
914                );
915                return Ok(());
916            }
917        }
918
919        let handle_ty = match *inner {
920            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
921            _ => inner,
922        };
923
924        // Samplers are handled entirely differently, so defer entirely to that method.
925        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
926
927        if is_sampler {
928            return self.write_global_sampler(module, handle, global);
929        }
930
931        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
932        let register_ty = match global.space {
933            crate::AddressSpace::Function => unreachable!("Function address space"),
934            crate::AddressSpace::Private => {
935                write!(self.out, "static ")?;
936                self.write_type(module, global.ty)?;
937                ""
938            }
939            crate::AddressSpace::WorkGroup => {
940                write!(self.out, "groupshared ")?;
941                self.write_type(module, global.ty)?;
942                ""
943            }
944            crate::AddressSpace::Uniform => {
945                // constant buffer declarations are expected to be inlined, e.g.
946                // `cbuffer foo: register(b0) { field1: type1; }`
947                write!(self.out, "cbuffer")?;
948                "b"
949            }
950            crate::AddressSpace::Storage { access } => {
951                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
952                    ("RW", "u")
953                } else {
954                    ("", "t")
955                };
956                write!(self.out, "{prefix}ByteAddressBuffer")?;
957                register
958            }
959            crate::AddressSpace::Handle => {
960                let register = match *handle_ty {
961                    // all storage textures are UAV, unconditionally
962                    TypeInner::Image {
963                        class: crate::ImageClass::Storage { .. },
964                        ..
965                    } => "u",
966                    _ => "t",
967                };
968                self.write_type(module, global.ty)?;
969                register
970            }
971            crate::AddressSpace::PushConstant => {
972                // The type of the push constants will be wrapped in `ConstantBuffer`
973                write!(self.out, "ConstantBuffer<")?;
974                "b"
975            }
976        };
977
978        // If the global is a push constant write the type now because it will be a
979        // generic argument to `ConstantBuffer`
980        if global.space == crate::AddressSpace::PushConstant {
981            self.write_global_type(module, global.ty)?;
982
983            // need to write the array size if the type was emitted with `write_type`
984            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
985                self.write_array_size(module, base, size)?;
986            }
987
988            // Close the angled brackets for the generic argument
989            write!(self.out, ">")?;
990        }
991
992        let name = &self.names[&NameKey::GlobalVariable(handle)];
993        write!(self.out, " {name}")?;
994
995        // Push constants need to be assigned a binding explicitly by the consumer
996        // since naga has no way to know the binding from the shader alone
997        if global.space == crate::AddressSpace::PushConstant {
998            match module.types[global.ty].inner {
999                TypeInner::Struct { .. } => {}
1000                _ => {
1001                    return Err(Error::Unimplemented(format!(
1002                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1003                    )));
1004                }
1005            }
1006
1007            let target = self
1008                .options
1009                .push_constants_target
1010                .as_ref()
1011                .expect("No bind target was defined for the push constants block");
1012            write!(self.out, ": register(b{}", target.register)?;
1013            if target.space != 0 {
1014                write!(self.out, ", space{}", target.space)?;
1015            }
1016            write!(self.out, ")")?;
1017        }
1018
1019        if let Some(ref binding) = global.binding {
1020            // this was already resolved earlier when we started evaluating an entry point.
1021            let bt = self.options.resolve_resource_binding(binding).unwrap();
1022
1023            // need to write the binding array size if the type was emitted with `write_type`
1024            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1025                if let Some(overridden_size) = bt.binding_array_size {
1026                    write!(self.out, "[{overridden_size}]")?;
1027                } else {
1028                    self.write_array_size(module, base, size)?;
1029                }
1030            }
1031
1032            write!(self.out, " : register({}{}", register_ty, bt.register)?;
1033            if bt.space != 0 {
1034                write!(self.out, ", space{}", bt.space)?;
1035            }
1036            write!(self.out, ")")?;
1037        } else {
1038            // need to write the array size if the type was emitted with `write_type`
1039            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1040                self.write_array_size(module, base, size)?;
1041            }
1042            if global.space == crate::AddressSpace::Private {
1043                write!(self.out, " = ")?;
1044                if let Some(init) = global.init {
1045                    self.write_const_expression(module, init, &module.global_expressions)?;
1046                } else {
1047                    self.write_default_init(module, global.ty)?;
1048                }
1049            }
1050        }
1051
1052        if global.space == crate::AddressSpace::Uniform {
1053            write!(self.out, " {{ ")?;
1054
1055            self.write_global_type(module, global.ty)?;
1056
1057            write!(
1058                self.out,
1059                " {}",
1060                &self.names[&NameKey::GlobalVariable(handle)]
1061            )?;
1062
1063            // need to write the array size if the type was emitted with `write_type`
1064            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1065                self.write_array_size(module, base, size)?;
1066            }
1067
1068            writeln!(self.out, "; }}")?;
1069        } else {
1070            writeln!(self.out, ";")?;
1071        }
1072
1073        Ok(())
1074    }
1075
1076    fn write_global_sampler(
1077        &mut self,
1078        module: &Module,
1079        handle: Handle<crate::GlobalVariable>,
1080        global: &crate::GlobalVariable,
1081    ) -> BackendResult {
1082        let binding = *global.binding.as_ref().unwrap();
1083
1084        let key = super::SamplerIndexBufferKey {
1085            group: binding.group,
1086        };
1087        self.write_wrapped_sampler_buffer(key)?;
1088
1089        // This was already validated, so we can confidently unwrap it.
1090        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1091
1092        match module.types[global.ty].inner {
1093            TypeInner::Sampler { comparison } => {
1094                // If we are generating a static access, we create a variable for the sampler.
1095                //
1096                // This prevents the DXIL from containing multiple lookups for the sampler, which
1097                // the backend compiler will then have to eliminate. AMD does seem to be able to
1098                // eliminate these, but better safe than sorry.
1099
1100                write!(self.out, "static const ")?;
1101                self.write_type(module, global.ty)?;
1102
1103                let heap_var = if comparison {
1104                    COMPARISON_SAMPLER_HEAP_VAR
1105                } else {
1106                    SAMPLER_HEAP_VAR
1107                };
1108
1109                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1110                let name = &self.names[&NameKey::GlobalVariable(handle)];
1111                writeln!(
1112                    self.out,
1113                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1114                    register = bt.register
1115                )?;
1116            }
1117            TypeInner::BindingArray { .. } => {
1118                // If we are generating a binding array, we cannot directly access the sampler as the index
1119                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1120                // that represents the "base" index into the sampler index buffer. This constant is added
1121                // to the user provided index to get the final index into the sampler index buffer.
1122
1123                let name = &self.names[&NameKey::GlobalVariable(handle)];
1124                writeln!(
1125                    self.out,
1126                    "static const uint {name} = {register};",
1127                    register = bt.register
1128                )?;
1129            }
1130            _ => unreachable!(),
1131        };
1132
1133        Ok(())
1134    }
1135
1136    /// Helper method used to write global constants
1137    ///
1138    /// # Notes
1139    /// Ends in a newline
1140    fn write_global_constant(
1141        &mut self,
1142        module: &Module,
1143        handle: Handle<crate::Constant>,
1144    ) -> BackendResult {
1145        write!(self.out, "static const ")?;
1146        let constant = &module.constants[handle];
1147        self.write_type(module, constant.ty)?;
1148        let name = &self.names[&NameKey::Constant(handle)];
1149        write!(self.out, " {name}")?;
1150        // Write size for array type
1151        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1152            self.write_array_size(module, base, size)?;
1153        }
1154        write!(self.out, " = ")?;
1155        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1156        writeln!(self.out, ";")?;
1157        Ok(())
1158    }
1159
1160    pub(super) fn write_array_size(
1161        &mut self,
1162        module: &Module,
1163        base: Handle<crate::Type>,
1164        size: crate::ArraySize,
1165    ) -> BackendResult {
1166        write!(self.out, "[")?;
1167
1168        match size.resolve(module.to_ctx())? {
1169            proc::IndexableLength::Known(size) => {
1170                write!(self.out, "{size}")?;
1171            }
1172            proc::IndexableLength::Dynamic => unreachable!(),
1173        }
1174
1175        write!(self.out, "]")?;
1176
1177        if let TypeInner::Array {
1178            base: next_base,
1179            size: next_size,
1180            ..
1181        } = module.types[base].inner
1182        {
1183            self.write_array_size(module, next_base, next_size)?;
1184        }
1185
1186        Ok(())
1187    }
1188
1189    /// Helper method used to write structs
1190    ///
1191    /// # Notes
1192    /// Ends in a newline
1193    fn write_struct(
1194        &mut self,
1195        module: &Module,
1196        handle: Handle<crate::Type>,
1197        members: &[crate::StructMember],
1198        span: u32,
1199        shader_stage: Option<(ShaderStage, Io)>,
1200    ) -> BackendResult {
1201        // Write struct name
1202        let struct_name = &self.names[&NameKey::Type(handle)];
1203        writeln!(self.out, "struct {struct_name} {{")?;
1204
1205        let mut last_offset = 0;
1206        for (index, member) in members.iter().enumerate() {
1207            if member.binding.is_none() && member.offset > last_offset {
1208                // using int as padding should work as long as the backend
1209                // doesn't support a type that's less than 4 bytes in size
1210                // (Error::UnsupportedScalar catches this)
1211                let padding = (member.offset - last_offset) / 4;
1212                for i in 0..padding {
1213                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1214                }
1215            }
1216            let ty_inner = &module.types[member.ty].inner;
1217            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1218
1219            // The indentation is only for readability
1220            write!(self.out, "{}", back::INDENT)?;
1221
1222            match module.types[member.ty].inner {
1223                TypeInner::Array { base, size, .. } => {
1224                    // HLSL arrays are written as `type name[size]`
1225
1226                    self.write_global_type(module, member.ty)?;
1227
1228                    // Write `name`
1229                    write!(
1230                        self.out,
1231                        " {}",
1232                        &self.names[&NameKey::StructMember(handle, index as u32)]
1233                    )?;
1234                    // Write [size]
1235                    self.write_array_size(module, base, size)?;
1236                }
1237                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1238                // See the module-level block comment in mod.rs for details.
1239                TypeInner::Matrix {
1240                    rows,
1241                    columns,
1242                    scalar,
1243                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1244                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1245                    let field_name_key = NameKey::StructMember(handle, index as u32);
1246
1247                    for i in 0..columns as u8 {
1248                        if i != 0 {
1249                            write!(self.out, "; ")?;
1250                        }
1251                        self.write_value_type(module, &vec_ty)?;
1252                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1253                    }
1254                }
1255                _ => {
1256                    // Write modifier before type
1257                    if let Some(ref binding) = member.binding {
1258                        self.write_modifier(binding)?;
1259                    }
1260
1261                    // Even though Naga IR matrices are column-major, we must describe
1262                    // matrices passed from the CPU as being in row-major order.
1263                    // See the module-level block comment in mod.rs for details.
1264                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1265                        write!(self.out, "row_major ")?;
1266                    }
1267
1268                    // Write the member type and name
1269                    self.write_type(module, member.ty)?;
1270                    write!(
1271                        self.out,
1272                        " {}",
1273                        &self.names[&NameKey::StructMember(handle, index as u32)]
1274                    )?;
1275                }
1276            }
1277
1278            self.write_semantic(&member.binding, shader_stage)?;
1279            writeln!(self.out, ";")?;
1280        }
1281
1282        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1283        if members.last().unwrap().binding.is_none() && span > last_offset {
1284            let padding = (span - last_offset) / 4;
1285            for i in 0..padding {
1286                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1287            }
1288        }
1289
1290        writeln!(self.out, "}};")?;
1291        Ok(())
1292    }
1293
1294    /// Helper method used to write global/structs non image/sampler types
1295    ///
1296    /// # Notes
1297    /// Adds no trailing or leading whitespace
1298    pub(super) fn write_global_type(
1299        &mut self,
1300        module: &Module,
1301        ty: Handle<crate::Type>,
1302    ) -> BackendResult {
1303        let matrix_data = get_inner_matrix_data(module, ty);
1304
1305        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1306        // See the module-level block comment in mod.rs for details.
1307        if let Some(MatrixType {
1308            columns,
1309            rows: crate::VectorSize::Bi,
1310            width: 4,
1311        }) = matrix_data
1312        {
1313            write!(self.out, "__mat{}x2", columns as u8)?;
1314        } else {
1315            // Even though Naga IR matrices are column-major, we must describe
1316            // matrices passed from the CPU as being in row-major order.
1317            // See the module-level block comment in mod.rs for details.
1318            if matrix_data.is_some() {
1319                write!(self.out, "row_major ")?;
1320            }
1321
1322            self.write_type(module, ty)?;
1323        }
1324
1325        Ok(())
1326    }
1327
1328    /// Helper method used to write non image/sampler types
1329    ///
1330    /// # Notes
1331    /// Adds no trailing or leading whitespace
1332    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1333        let inner = &module.types[ty].inner;
1334        match *inner {
1335            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1336            // hlsl array has the size separated from the base type
1337            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1338                self.write_type(module, base)?
1339            }
1340            ref other => self.write_value_type(module, other)?,
1341        }
1342
1343        Ok(())
1344    }
1345
1346    /// Helper method used to write value types
1347    ///
1348    /// # Notes
1349    /// Adds no trailing or leading whitespace
1350    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1351        match *inner {
1352            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1353                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1354            }
1355            TypeInner::Vector { size, scalar } => {
1356                write!(
1357                    self.out,
1358                    "{}{}",
1359                    scalar.to_hlsl_str()?,
1360                    common::vector_size_str(size)
1361                )?;
1362            }
1363            TypeInner::Matrix {
1364                columns,
1365                rows,
1366                scalar,
1367            } => {
1368                // The IR supports only float matrix
1369                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1370
1371                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1372                write!(
1373                    self.out,
1374                    "{}{}x{}",
1375                    scalar.to_hlsl_str()?,
1376                    common::vector_size_str(columns),
1377                    common::vector_size_str(rows),
1378                )?;
1379            }
1380            TypeInner::Image {
1381                dim,
1382                arrayed,
1383                class,
1384            } => {
1385                self.write_image_type(dim, arrayed, class)?;
1386            }
1387            TypeInner::Sampler { comparison } => {
1388                let sampler = if comparison {
1389                    "SamplerComparisonState"
1390                } else {
1391                    "SamplerState"
1392                };
1393                write!(self.out, "{sampler}")?;
1394            }
1395            // HLSL arrays are written as `type name[size]`
1396            // Current code is written arrays only as `[size]`
1397            // Base `type` and `name` should be written outside
1398            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1399                self.write_array_size(module, base, size)?;
1400            }
1401            TypeInner::AccelerationStructure { .. } => {
1402                write!(self.out, "RaytracingAccelerationStructure")?;
1403            }
1404            TypeInner::RayQuery { .. } => {
1405                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1406                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1407            }
1408            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1409        }
1410
1411        Ok(())
1412    }
1413
1414    /// Helper method used to write functions
1415    /// # Notes
1416    /// Ends in a newline
1417    fn write_function(
1418        &mut self,
1419        module: &Module,
1420        name: &str,
1421        func: &crate::Function,
1422        func_ctx: &back::FunctionCtx<'_>,
1423        info: &valid::FunctionInfo,
1424    ) -> BackendResult {
1425        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1426
1427        self.update_expressions_to_bake(module, func, info);
1428
1429        if let Some(ref result) = func.result {
1430            // Write typedef if return type is an array
1431            let array_return_type = match module.types[result.ty].inner {
1432                TypeInner::Array { base, size, .. } => {
1433                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1434                    write!(self.out, "typedef ")?;
1435                    self.write_type(module, result.ty)?;
1436                    write!(self.out, " {array_return_type}")?;
1437                    self.write_array_size(module, base, size)?;
1438                    writeln!(self.out, ";")?;
1439                    Some(array_return_type)
1440                }
1441                _ => None,
1442            };
1443
1444            // Write modifier
1445            if let Some(
1446                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1447            ) = result.binding
1448            {
1449                self.write_modifier(binding)?;
1450            }
1451
1452            // Write return type
1453            match func_ctx.ty {
1454                back::FunctionType::Function(_) => {
1455                    if let Some(array_return_type) = array_return_type {
1456                        write!(self.out, "{array_return_type}")?;
1457                    } else {
1458                        self.write_type(module, result.ty)?;
1459                    }
1460                }
1461                back::FunctionType::EntryPoint(index) => {
1462                    if let Some(ref ep_output) =
1463                        self.entry_point_io.get(&(index as usize)).unwrap().output
1464                    {
1465                        write!(self.out, "{}", ep_output.ty_name)?;
1466                    } else {
1467                        self.write_type(module, result.ty)?;
1468                    }
1469                }
1470            }
1471        } else {
1472            write!(self.out, "void")?;
1473        }
1474
1475        // Write function name
1476        write!(self.out, " {name}(")?;
1477
1478        let need_workgroup_variables_initialization =
1479            self.need_workgroup_variables_initialization(func_ctx, module);
1480
1481        // Write function arguments for non entry point functions
1482        match func_ctx.ty {
1483            back::FunctionType::Function(handle) => {
1484                for (index, arg) in func.arguments.iter().enumerate() {
1485                    if index != 0 {
1486                        write!(self.out, ", ")?;
1487                    }
1488                    // Write argument type
1489                    let arg_ty = match module.types[arg.ty].inner {
1490                        // pointers in function arguments are expected and resolve to `inout`
1491                        TypeInner::Pointer { base, .. } => {
1492                            //TODO: can we narrow this down to just `in` when possible?
1493                            write!(self.out, "inout ")?;
1494                            base
1495                        }
1496                        _ => arg.ty,
1497                    };
1498                    self.write_type(module, arg_ty)?;
1499
1500                    let argument_name =
1501                        &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1502
1503                    // Write argument name. Space is important.
1504                    write!(self.out, " {argument_name}")?;
1505                    if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1506                        self.write_array_size(module, base, size)?;
1507                    }
1508                }
1509            }
1510            back::FunctionType::EntryPoint(ep_index) => {
1511                if let Some(ref ep_input) =
1512                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1513                {
1514                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1515                } else {
1516                    let stage = module.entry_points[ep_index as usize].stage;
1517                    for (index, arg) in func.arguments.iter().enumerate() {
1518                        if index != 0 {
1519                            write!(self.out, ", ")?;
1520                        }
1521                        self.write_type(module, arg.ty)?;
1522
1523                        let argument_name =
1524                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1525
1526                        write!(self.out, " {argument_name}")?;
1527                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1528                            self.write_array_size(module, base, size)?;
1529                        }
1530
1531                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1532                    }
1533                }
1534                if need_workgroup_variables_initialization {
1535                    if self
1536                        .entry_point_io
1537                        .get(&(ep_index as usize))
1538                        .unwrap()
1539                        .input
1540                        .is_some()
1541                        || !func.arguments.is_empty()
1542                    {
1543                        write!(self.out, ", ")?;
1544                    }
1545                    write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
1546                }
1547            }
1548        }
1549        // Ends of arguments
1550        write!(self.out, ")")?;
1551
1552        // Write semantic if it present
1553        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1554            let stage = module.entry_points[index as usize].stage;
1555            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1556                self.write_semantic(binding, Some((stage, Io::Output)))?;
1557            }
1558        }
1559
1560        // Function body start
1561        writeln!(self.out)?;
1562        writeln!(self.out, "{{")?;
1563
1564        if need_workgroup_variables_initialization {
1565            self.write_workgroup_variables_initialization(func_ctx, module)?;
1566        }
1567
1568        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1569            self.write_ep_arguments_initialization(module, func, index)?;
1570        }
1571
1572        // Write function local variables
1573        for (handle, local) in func.local_variables.iter() {
1574            // Write indentation (only for readability)
1575            write!(self.out, "{}", back::INDENT)?;
1576
1577            // Write the local name
1578            // The leading space is important
1579            self.write_type(module, local.ty)?;
1580            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1581            // Write size for array type
1582            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1583                self.write_array_size(module, base, size)?;
1584            }
1585
1586            match module.types[local.ty].inner {
1587                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1588                TypeInner::RayQuery { .. } => {}
1589                _ => {
1590                    write!(self.out, " = ")?;
1591                    // Write the local initializer if needed
1592                    if let Some(init) = local.init {
1593                        self.write_expr(module, init, func_ctx)?;
1594                    } else {
1595                        // Zero initialize local variables
1596                        self.write_default_init(module, local.ty)?;
1597                    }
1598                }
1599            }
1600            // Finish the local with `;` and add a newline (only for readability)
1601            writeln!(self.out, ";")?
1602        }
1603
1604        if !func.local_variables.is_empty() {
1605            writeln!(self.out)?;
1606        }
1607
1608        // Write the function body (statement list)
1609        for sta in func.body.iter() {
1610            // The indentation should always be 1 when writing the function body
1611            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1612        }
1613
1614        writeln!(self.out, "}}")?;
1615
1616        self.named_expressions.clear();
1617
1618        Ok(())
1619    }
1620
1621    fn need_workgroup_variables_initialization(
1622        &mut self,
1623        func_ctx: &back::FunctionCtx,
1624        module: &Module,
1625    ) -> bool {
1626        self.options.zero_initialize_workgroup_memory
1627            && func_ctx.ty.is_compute_entry_point(module)
1628            && module.global_variables.iter().any(|(handle, var)| {
1629                !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1630            })
1631    }
1632
1633    fn write_workgroup_variables_initialization(
1634        &mut self,
1635        func_ctx: &back::FunctionCtx,
1636        module: &Module,
1637    ) -> BackendResult {
1638        let level = back::Level(1);
1639
1640        writeln!(
1641            self.out,
1642            "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1643        )?;
1644
1645        let vars = module.global_variables.iter().filter(|&(handle, var)| {
1646            !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1647        });
1648
1649        for (handle, var) in vars {
1650            let name = &self.names[&NameKey::GlobalVariable(handle)];
1651            write!(self.out, "{}{} = ", level.next(), name)?;
1652            self.write_default_init(module, var.ty)?;
1653            writeln!(self.out, ";")?;
1654        }
1655
1656        writeln!(self.out, "{level}}}")?;
1657        self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1658    }
1659
1660    /// Helper method used to write switches
1661    fn write_switch(
1662        &mut self,
1663        module: &Module,
1664        func_ctx: &back::FunctionCtx<'_>,
1665        level: back::Level,
1666        selector: Handle<crate::Expression>,
1667        cases: &[crate::SwitchCase],
1668    ) -> BackendResult {
1669        // Write all cases
1670        let indent_level_1 = level.next();
1671        let indent_level_2 = indent_level_1.next();
1672
1673        // See docs of `back::continue_forward` module.
1674        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
1675            writeln!(self.out, "{level}bool {variable} = false;",)?;
1676        };
1677
1678        // Check if there is only one body, by seeing if all except the last case are fall through
1679        // with empty bodies. FXC doesn't handle these switches correctly, so
1680        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
1681        // is no need to check if one of the cases would have matched.
1682        let one_body = cases
1683            .iter()
1684            .rev()
1685            .skip(1)
1686            .all(|case| case.fall_through && case.body.is_empty());
1687        if one_body {
1688            // Start the do-while
1689            writeln!(self.out, "{level}do {{")?;
1690            // Note: Expressions have no side-effects so we don't need to emit selector expression.
1691
1692            // Body
1693            if let Some(case) = cases.last() {
1694                for sta in case.body.iter() {
1695                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
1696                }
1697            }
1698            // End do-while
1699            writeln!(self.out, "{level}}} while(false);")?;
1700        } else {
1701            // Start the switch
1702            write!(self.out, "{level}")?;
1703            write!(self.out, "switch(")?;
1704            self.write_expr(module, selector, func_ctx)?;
1705            writeln!(self.out, ") {{")?;
1706
1707            for (i, case) in cases.iter().enumerate() {
1708                match case.value {
1709                    crate::SwitchValue::I32(value) => {
1710                        write!(self.out, "{indent_level_1}case {value}:")?
1711                    }
1712                    crate::SwitchValue::U32(value) => {
1713                        write!(self.out, "{indent_level_1}case {value}u:")?
1714                    }
1715                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
1716                }
1717
1718                // The new block is not only stylistic, it plays a role here:
1719                // We might end up having to write the same case body
1720                // multiple times due to FXC not supporting fallthrough.
1721                // Therefore, some `Expression`s written by `Statement::Emit`
1722                // will end up having the same name (`_expr<handle_index>`).
1723                // So we need to put each case in its own scope.
1724                let write_block_braces = !(case.fall_through && case.body.is_empty());
1725                if write_block_braces {
1726                    writeln!(self.out, " {{")?;
1727                } else {
1728                    writeln!(self.out)?;
1729                }
1730
1731                // Although FXC does support a series of case clauses before
1732                // a block[^yes], it does not support fallthrough from a
1733                // non-empty case block to the next[^no]. If this case has a
1734                // non-empty body with a fallthrough, emulate that by
1735                // duplicating the bodies of all the cases it would fall
1736                // into as extensions of this case's own body. This makes
1737                // the HLSL output potentially quadratic in the size of the
1738                // Naga IR.
1739                //
1740                // [^yes]: ```hlsl
1741                // case 1:
1742                // case 2: do_stuff()
1743                // ```
1744                // [^no]: ```hlsl
1745                // case 1: do_this();
1746                // case 2: do_that();
1747                // ```
1748                if case.fall_through && !case.body.is_empty() {
1749                    let curr_len = i + 1;
1750                    let end_case_idx = curr_len
1751                        + cases
1752                            .iter()
1753                            .skip(curr_len)
1754                            .position(|case| !case.fall_through)
1755                            .unwrap();
1756                    let indent_level_3 = indent_level_2.next();
1757                    for case in &cases[i..=end_case_idx] {
1758                        writeln!(self.out, "{indent_level_2}{{")?;
1759                        let prev_len = self.named_expressions.len();
1760                        for sta in case.body.iter() {
1761                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
1762                        }
1763                        // Clear all named expressions that were previously inserted by the statements in the block
1764                        self.named_expressions.truncate(prev_len);
1765                        writeln!(self.out, "{indent_level_2}}}")?;
1766                    }
1767
1768                    let last_case = &cases[end_case_idx];
1769                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
1770                        writeln!(self.out, "{indent_level_2}break;")?;
1771                    }
1772                } else {
1773                    for sta in case.body.iter() {
1774                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
1775                    }
1776                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
1777                        writeln!(self.out, "{indent_level_2}break;")?;
1778                    }
1779                }
1780
1781                if write_block_braces {
1782                    writeln!(self.out, "{indent_level_1}}}")?;
1783                }
1784            }
1785
1786            writeln!(self.out, "{level}}}")?;
1787        }
1788
1789        // Handle any forwarded continue statements.
1790        use back::continue_forward::ExitControlFlow;
1791        let op = match self.continue_ctx.exit_switch() {
1792            ExitControlFlow::None => None,
1793            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
1794            ExitControlFlow::Break { variable } => Some(("break", variable)),
1795        };
1796        if let Some((control_flow, variable)) = op {
1797            writeln!(self.out, "{level}if ({variable}) {{")?;
1798            writeln!(self.out, "{indent_level_1}{control_flow};")?;
1799            writeln!(self.out, "{level}}}")?;
1800        }
1801
1802        Ok(())
1803    }
1804
1805    fn write_index(
1806        &mut self,
1807        module: &Module,
1808        index: Index,
1809        func_ctx: &back::FunctionCtx<'_>,
1810    ) -> BackendResult {
1811        match index {
1812            Index::Static(index) => {
1813                write!(self.out, "{index}")?;
1814            }
1815            Index::Expression(index) => {
1816                self.write_expr(module, index, func_ctx)?;
1817            }
1818        }
1819        Ok(())
1820    }
1821
1822    /// Helper method used to write statements
1823    ///
1824    /// # Notes
1825    /// Always adds a newline
1826    fn write_stmt(
1827        &mut self,
1828        module: &Module,
1829        stmt: &crate::Statement,
1830        func_ctx: &back::FunctionCtx<'_>,
1831        level: back::Level,
1832    ) -> BackendResult {
1833        use crate::Statement;
1834
1835        match *stmt {
1836            Statement::Emit(ref range) => {
1837                for handle in range.clone() {
1838                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
1839                    let expr_name = if ptr_class.is_some() {
1840                        // HLSL can't save a pointer-valued expression in a variable,
1841                        // but we shouldn't ever need to: they should never be named expressions,
1842                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
1843                        None
1844                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
1845                        // Front end provides names for all variables at the start of writing.
1846                        // But we write them to step by step. We need to recache them
1847                        // Otherwise, we could accidentally write variable name instead of full expression.
1848                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
1849                        Some(self.namer.call(name))
1850                    } else if self.need_bake_expressions.contains(&handle) {
1851                        Some(Baked(handle).to_string())
1852                    } else {
1853                        None
1854                    };
1855
1856                    if let Some(name) = expr_name {
1857                        write!(self.out, "{level}")?;
1858                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
1859                    }
1860                }
1861            }
1862            // TODO: copy-paste from glsl-out
1863            Statement::Block(ref block) => {
1864                write!(self.out, "{level}")?;
1865                writeln!(self.out, "{{")?;
1866                for sta in block.iter() {
1867                    // Increase the indentation to help with readability
1868                    self.write_stmt(module, sta, func_ctx, level.next())?
1869                }
1870                writeln!(self.out, "{level}}}")?
1871            }
1872            // TODO: copy-paste from glsl-out
1873            Statement::If {
1874                condition,
1875                ref accept,
1876                ref reject,
1877            } => {
1878                write!(self.out, "{level}")?;
1879                write!(self.out, "if (")?;
1880                self.write_expr(module, condition, func_ctx)?;
1881                writeln!(self.out, ") {{")?;
1882
1883                let l2 = level.next();
1884                for sta in accept {
1885                    // Increase indentation to help with readability
1886                    self.write_stmt(module, sta, func_ctx, l2)?;
1887                }
1888
1889                // If there are no statements in the reject block we skip writing it
1890                // This is only for readability
1891                if !reject.is_empty() {
1892                    writeln!(self.out, "{level}}} else {{")?;
1893
1894                    for sta in reject {
1895                        // Increase indentation to help with readability
1896                        self.write_stmt(module, sta, func_ctx, l2)?;
1897                    }
1898                }
1899
1900                writeln!(self.out, "{level}}}")?
1901            }
1902            // TODO: copy-paste from glsl-out
1903            Statement::Kill => writeln!(self.out, "{level}discard;")?,
1904            Statement::Return { value: None } => {
1905                writeln!(self.out, "{level}return;")?;
1906            }
1907            Statement::Return { value: Some(expr) } => {
1908                let base_ty_res = &func_ctx.info[expr].ty;
1909                let mut resolved = base_ty_res.inner_with(&module.types);
1910                if let TypeInner::Pointer { base, space: _ } = *resolved {
1911                    resolved = &module.types[base].inner;
1912                }
1913
1914                if let TypeInner::Struct { .. } = *resolved {
1915                    // We can safely unwrap here, since we now we working with struct
1916                    let ty = base_ty_res.handle().unwrap();
1917                    let struct_name = &self.names[&NameKey::Type(ty)];
1918                    let variable_name = self.namer.call(&struct_name.to_lowercase());
1919                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
1920                    self.write_expr(module, expr, func_ctx)?;
1921                    writeln!(self.out, ";")?;
1922
1923                    // for entry point returns, we may need to reshuffle the outputs into a different struct
1924                    let ep_output = match func_ctx.ty {
1925                        back::FunctionType::Function(_) => None,
1926                        back::FunctionType::EntryPoint(index) => self
1927                            .entry_point_io
1928                            .get(&(index as usize))
1929                            .unwrap()
1930                            .output
1931                            .as_ref(),
1932                    };
1933                    let final_name = match ep_output {
1934                        Some(ep_output) => {
1935                            let final_name = self.namer.call(&variable_name);
1936                            write!(
1937                                self.out,
1938                                "{}const {} {} = {{ ",
1939                                level, ep_output.ty_name, final_name,
1940                            )?;
1941                            for (index, m) in ep_output.members.iter().enumerate() {
1942                                if index != 0 {
1943                                    write!(self.out, ", ")?;
1944                                }
1945                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
1946                                write!(self.out, "{variable_name}.{member_name}")?;
1947                            }
1948                            writeln!(self.out, " }};")?;
1949                            final_name
1950                        }
1951                        None => variable_name,
1952                    };
1953                    writeln!(self.out, "{level}return {final_name};")?;
1954                } else {
1955                    write!(self.out, "{level}return ")?;
1956                    self.write_expr(module, expr, func_ctx)?;
1957                    writeln!(self.out, ";")?
1958                }
1959            }
1960            Statement::Store { pointer, value } => {
1961                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
1962                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
1963                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
1964                    self.write_storage_store(
1965                        module,
1966                        var_handle,
1967                        StoreValue::Expression(value),
1968                        func_ctx,
1969                        level,
1970                        None,
1971                    )?;
1972                } else {
1973                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1974                    // See the module-level block comment in mod.rs for details.
1975                    //
1976                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
1977                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
1978                    enum MatrixAccess {
1979                        Direct {
1980                            base: Handle<crate::Expression>,
1981                            index: u32,
1982                        },
1983                        Struct {
1984                            columns: crate::VectorSize,
1985                            base: Handle<crate::Expression>,
1986                        },
1987                    }
1988
1989                    let get_members = |expr: Handle<crate::Expression>| {
1990                        let resolved = func_ctx.resolve_type(expr, &module.types);
1991                        match *resolved {
1992                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
1993                                TypeInner::Struct { ref members, .. } => Some(members),
1994                                _ => None,
1995                            },
1996                            _ => None,
1997                        }
1998                    };
1999
2000                    write!(self.out, "{level}")?;
2001
2002                    let matrix_access_on_lhs =
2003                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2004                            |(matrix_expr, vector, scalar)| match (
2005                                func_ctx.resolve_type(matrix_expr, &module.types),
2006                                &func_ctx.expressions[matrix_expr],
2007                            ) {
2008                                (
2009                                    &TypeInner::Pointer { base: ty, .. },
2010                                    &crate::Expression::AccessIndex { base, index },
2011                                ) if matches!(
2012                                    module.types[ty].inner,
2013                                    TypeInner::Matrix {
2014                                        rows: crate::VectorSize::Bi,
2015                                        ..
2016                                    }
2017                                ) && get_members(base)
2018                                    .map(|members| members[index as usize].binding.is_none())
2019                                    == Some(true) =>
2020                                {
2021                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2022                                }
2023                                _ => {
2024                                    if let Some(MatrixType {
2025                                        columns,
2026                                        rows: crate::VectorSize::Bi,
2027                                        width: 4,
2028                                    }) = get_inner_matrix_of_struct_array_member(
2029                                        module,
2030                                        matrix_expr,
2031                                        func_ctx,
2032                                        true,
2033                                    ) {
2034                                        Some((
2035                                            MatrixAccess::Struct {
2036                                                columns,
2037                                                base: matrix_expr,
2038                                            },
2039                                            vector,
2040                                            scalar,
2041                                        ))
2042                                    } else {
2043                                        None
2044                                    }
2045                                }
2046                            },
2047                        );
2048
2049                    match matrix_access_on_lhs {
2050                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2051                            let base_ty_res = &func_ctx.info[base].ty;
2052                            let resolved = base_ty_res.inner_with(&module.types);
2053                            let ty = match *resolved {
2054                                TypeInner::Pointer { base, .. } => base,
2055                                _ => base_ty_res.handle().unwrap(),
2056                            };
2057
2058                            if let Some(Index::Static(vec_index)) = vector {
2059                                self.write_expr(module, base, func_ctx)?;
2060                                write!(
2061                                    self.out,
2062                                    ".{}_{}",
2063                                    &self.names[&NameKey::StructMember(ty, index)],
2064                                    vec_index
2065                                )?;
2066
2067                                if let Some(scalar_index) = scalar {
2068                                    write!(self.out, "[")?;
2069                                    self.write_index(module, scalar_index, func_ctx)?;
2070                                    write!(self.out, "]")?;
2071                                }
2072
2073                                write!(self.out, " = ")?;
2074                                self.write_expr(module, value, func_ctx)?;
2075                                writeln!(self.out, ";")?;
2076                            } else {
2077                                let access = WrappedStructMatrixAccess { ty, index };
2078                                match (&vector, &scalar) {
2079                                    (&Some(_), &Some(_)) => {
2080                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2081                                            access,
2082                                        )?;
2083                                    }
2084                                    (&Some(_), &None) => {
2085                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2086                                            access,
2087                                        )?;
2088                                    }
2089                                    (&None, _) => {
2090                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2091                                    }
2092                                }
2093
2094                                write!(self.out, "(")?;
2095                                self.write_expr(module, base, func_ctx)?;
2096                                write!(self.out, ", ")?;
2097                                self.write_expr(module, value, func_ctx)?;
2098
2099                                if let Some(Index::Expression(vec_index)) = vector {
2100                                    write!(self.out, ", ")?;
2101                                    self.write_expr(module, vec_index, func_ctx)?;
2102
2103                                    if let Some(scalar_index) = scalar {
2104                                        write!(self.out, ", ")?;
2105                                        self.write_index(module, scalar_index, func_ctx)?;
2106                                    }
2107                                }
2108                                writeln!(self.out, ");")?;
2109                            }
2110                        }
2111                        Some((
2112                            MatrixAccess::Struct { columns, base },
2113                            Some(Index::Expression(vec_index)),
2114                            scalar,
2115                        )) => {
2116                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2117                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2118
2119                            if scalar.is_some() {
2120                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2121                            } else {
2122                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2123                            }
2124                            write!(self.out, "(")?;
2125                            self.write_expr(module, base, func_ctx)?;
2126                            write!(self.out, ", ")?;
2127                            self.write_expr(module, vec_index, func_ctx)?;
2128
2129                            if let Some(scalar_index) = scalar {
2130                                write!(self.out, ", ")?;
2131                                self.write_index(module, scalar_index, func_ctx)?;
2132                            }
2133
2134                            write!(self.out, ", ")?;
2135                            self.write_expr(module, value, func_ctx)?;
2136
2137                            writeln!(self.out, ");")?;
2138                        }
2139                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2140                        | Some((MatrixAccess::Struct { .. }, None, _))
2141                        | None => {
2142                            self.write_expr(module, pointer, func_ctx)?;
2143                            write!(self.out, " = ")?;
2144
2145                            // We cast the RHS of this store in cases where the LHS
2146                            // is a struct member with type:
2147                            //  - matCx2 or
2148                            //  - a (possibly nested) array of matCx2's
2149                            if let Some(MatrixType {
2150                                columns,
2151                                rows: crate::VectorSize::Bi,
2152                                width: 4,
2153                            }) = get_inner_matrix_of_struct_array_member(
2154                                module, pointer, func_ctx, false,
2155                            ) {
2156                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2157                                if let TypeInner::Pointer { base, .. } = *resolved {
2158                                    resolved = &module.types[base].inner;
2159                                }
2160
2161                                write!(self.out, "(__mat{}x2", columns as u8)?;
2162                                if let TypeInner::Array { base, size, .. } = *resolved {
2163                                    self.write_array_size(module, base, size)?;
2164                                }
2165                                write!(self.out, ")")?;
2166                            }
2167
2168                            self.write_expr(module, value, func_ctx)?;
2169                            writeln!(self.out, ";")?
2170                        }
2171                    }
2172                }
2173            }
2174            Statement::Loop {
2175                ref body,
2176                ref continuing,
2177                break_if,
2178            } => {
2179                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2180                let gate_name = (!continuing.is_empty() || break_if.is_some())
2181                    .then(|| self.namer.call("loop_init"));
2182
2183                if let Some((ref decl, _)) = force_loop_bound_statements {
2184                    writeln!(self.out, "{decl}")?;
2185                }
2186                if let Some(ref gate_name) = gate_name {
2187                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2188                }
2189
2190                self.continue_ctx.enter_loop();
2191                writeln!(self.out, "{level}while(true) {{")?;
2192                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2193                    writeln!(self.out, "{break_and_inc}")?;
2194                }
2195                let l2 = level.next();
2196                if let Some(gate_name) = gate_name {
2197                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2198                    let l3 = l2.next();
2199                    for sta in continuing.iter() {
2200                        self.write_stmt(module, sta, func_ctx, l3)?;
2201                    }
2202                    if let Some(condition) = break_if {
2203                        write!(self.out, "{l3}if (")?;
2204                        self.write_expr(module, condition, func_ctx)?;
2205                        writeln!(self.out, ") {{")?;
2206                        writeln!(self.out, "{}break;", l3.next())?;
2207                        writeln!(self.out, "{l3}}}")?;
2208                    }
2209                    writeln!(self.out, "{l2}}}")?;
2210                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2211                }
2212
2213                for sta in body.iter() {
2214                    self.write_stmt(module, sta, func_ctx, l2)?;
2215                }
2216
2217                writeln!(self.out, "{level}}}")?;
2218                self.continue_ctx.exit_loop();
2219            }
2220            Statement::Break => writeln!(self.out, "{level}break;")?,
2221            Statement::Continue => {
2222                if let Some(variable) = self.continue_ctx.continue_encountered() {
2223                    writeln!(self.out, "{level}{variable} = true;")?;
2224                    writeln!(self.out, "{level}break;")?
2225                } else {
2226                    writeln!(self.out, "{level}continue;")?
2227                }
2228            }
2229            Statement::ControlBarrier(barrier) => {
2230                self.write_control_barrier(barrier, level)?;
2231            }
2232            Statement::MemoryBarrier(barrier) => {
2233                self.write_memory_barrier(barrier, level)?;
2234            }
2235            Statement::ImageStore {
2236                image,
2237                coordinate,
2238                array_index,
2239                value,
2240            } => {
2241                write!(self.out, "{level}")?;
2242                self.write_expr(module, image, func_ctx)?;
2243
2244                write!(self.out, "[")?;
2245                if let Some(index) = array_index {
2246                    // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
2247                    write!(self.out, "int3(")?;
2248                    self.write_expr(module, coordinate, func_ctx)?;
2249                    write!(self.out, ", ")?;
2250                    self.write_expr(module, index, func_ctx)?;
2251                    write!(self.out, ")")?;
2252                } else {
2253                    self.write_expr(module, coordinate, func_ctx)?;
2254                }
2255                write!(self.out, "]")?;
2256
2257                write!(self.out, " = ")?;
2258                self.write_expr(module, value, func_ctx)?;
2259                writeln!(self.out, ";")?;
2260            }
2261            Statement::Call {
2262                function,
2263                ref arguments,
2264                result,
2265            } => {
2266                write!(self.out, "{level}")?;
2267                if let Some(expr) = result {
2268                    write!(self.out, "const ")?;
2269                    let name = Baked(expr).to_string();
2270                    let expr_ty = &func_ctx.info[expr].ty;
2271                    let ty_inner = match *expr_ty {
2272                        proc::TypeResolution::Handle(handle) => {
2273                            self.write_type(module, handle)?;
2274                            &module.types[handle].inner
2275                        }
2276                        proc::TypeResolution::Value(ref value) => {
2277                            self.write_value_type(module, value)?;
2278                            value
2279                        }
2280                    };
2281                    write!(self.out, " {name}")?;
2282                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2283                        self.write_array_size(module, base, size)?;
2284                    }
2285                    write!(self.out, " = ")?;
2286                    self.named_expressions.insert(expr, name);
2287                }
2288                let func_name = &self.names[&NameKey::Function(function)];
2289                write!(self.out, "{func_name}(")?;
2290                for (index, argument) in arguments.iter().enumerate() {
2291                    if index != 0 {
2292                        write!(self.out, ", ")?;
2293                    }
2294                    self.write_expr(module, *argument, func_ctx)?;
2295                }
2296                writeln!(self.out, ");")?
2297            }
2298            Statement::Atomic {
2299                pointer,
2300                ref fun,
2301                value,
2302                result,
2303            } => {
2304                write!(self.out, "{level}")?;
2305                let res_var_info = if let Some(res_handle) = result {
2306                    let name = Baked(res_handle).to_string();
2307                    match func_ctx.info[res_handle].ty {
2308                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2309                        proc::TypeResolution::Value(ref value) => {
2310                            self.write_value_type(module, value)?
2311                        }
2312                    };
2313                    write!(self.out, " {name}; ")?;
2314                    self.named_expressions.insert(res_handle, name.clone());
2315                    Some((res_handle, name))
2316                } else {
2317                    None
2318                };
2319                let pointer_space = func_ctx
2320                    .resolve_type(pointer, &module.types)
2321                    .pointer_space()
2322                    .unwrap();
2323                let fun_str = fun.to_hlsl_suffix();
2324                let compare_expr = match *fun {
2325                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2326                    _ => None,
2327                };
2328                match pointer_space {
2329                    crate::AddressSpace::WorkGroup => {
2330                        write!(self.out, "Interlocked{fun_str}(")?;
2331                        self.write_expr(module, pointer, func_ctx)?;
2332                        self.emit_hlsl_atomic_tail(
2333                            module,
2334                            func_ctx,
2335                            fun,
2336                            compare_expr,
2337                            value,
2338                            &res_var_info,
2339                        )?;
2340                    }
2341                    crate::AddressSpace::Storage { .. } => {
2342                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2343                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2344                        let width = match func_ctx.resolve_type(value, &module.types) {
2345                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2346                            _ => "",
2347                        };
2348                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2349                        let chain = mem::take(&mut self.temp_access_chain);
2350                        self.write_storage_address(module, &chain, func_ctx)?;
2351                        self.temp_access_chain = chain;
2352                        self.emit_hlsl_atomic_tail(
2353                            module,
2354                            func_ctx,
2355                            fun,
2356                            compare_expr,
2357                            value,
2358                            &res_var_info,
2359                        )?;
2360                    }
2361                    ref other => {
2362                        return Err(Error::Custom(format!(
2363                            "invalid address space {other:?} for atomic statement"
2364                        )))
2365                    }
2366                }
2367                if let Some(cmp) = compare_expr {
2368                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2369                        write!(
2370                            self.out,
2371                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2372                        )?;
2373                        self.write_expr(module, cmp, func_ctx)?;
2374                        writeln!(self.out, ");")?;
2375                    }
2376                }
2377            }
2378            Statement::ImageAtomic {
2379                image,
2380                coordinate,
2381                array_index,
2382                fun,
2383                value,
2384            } => {
2385                write!(self.out, "{level}")?;
2386
2387                let fun_str = fun.to_hlsl_suffix();
2388                write!(self.out, "Interlocked{fun_str}(")?;
2389                self.write_expr(module, image, func_ctx)?;
2390                write!(self.out, "[")?;
2391                self.write_texture_coordinates(
2392                    "int",
2393                    coordinate,
2394                    array_index,
2395                    None,
2396                    module,
2397                    func_ctx,
2398                )?;
2399                write!(self.out, "],")?;
2400
2401                self.write_expr(module, value, func_ctx)?;
2402                writeln!(self.out, ");")?;
2403            }
2404            Statement::WorkGroupUniformLoad { pointer, result } => {
2405                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2406                write!(self.out, "{level}")?;
2407                let name = Baked(result).to_string();
2408                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2409
2410                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2411            }
2412            Statement::Switch {
2413                selector,
2414                ref cases,
2415            } => {
2416                self.write_switch(module, func_ctx, level, selector, cases)?;
2417            }
2418            Statement::RayQuery { query, ref fun } => match *fun {
2419                RayQueryFunction::Initialize {
2420                    acceleration_structure,
2421                    descriptor,
2422                } => {
2423                    write!(self.out, "{level}")?;
2424                    self.write_expr(module, query, func_ctx)?;
2425                    write!(self.out, ".TraceRayInline(")?;
2426                    self.write_expr(module, acceleration_structure, func_ctx)?;
2427                    write!(self.out, ", ")?;
2428                    self.write_expr(module, descriptor, func_ctx)?;
2429                    write!(self.out, ".flags, ")?;
2430                    self.write_expr(module, descriptor, func_ctx)?;
2431                    write!(self.out, ".cull_mask, ")?;
2432                    write!(self.out, "RayDescFromRayDesc_(")?;
2433                    self.write_expr(module, descriptor, func_ctx)?;
2434                    writeln!(self.out, "));")?;
2435                }
2436                RayQueryFunction::Proceed { result } => {
2437                    write!(self.out, "{level}")?;
2438                    let name = Baked(result).to_string();
2439                    write!(self.out, "const bool {name} = ")?;
2440                    self.named_expressions.insert(result, name);
2441                    self.write_expr(module, query, func_ctx)?;
2442                    writeln!(self.out, ".Proceed();")?;
2443                }
2444                RayQueryFunction::GenerateIntersection { hit_t } => {
2445                    write!(self.out, "{level}")?;
2446                    self.write_expr(module, query, func_ctx)?;
2447                    write!(self.out, ".CommitProceduralPrimitiveHit(")?;
2448                    self.write_expr(module, hit_t, func_ctx)?;
2449                    writeln!(self.out, ");")?;
2450                }
2451                RayQueryFunction::ConfirmIntersection => {
2452                    write!(self.out, "{level}")?;
2453                    self.write_expr(module, query, func_ctx)?;
2454                    writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
2455                }
2456                RayQueryFunction::Terminate => {
2457                    write!(self.out, "{level}")?;
2458                    self.write_expr(module, query, func_ctx)?;
2459                    writeln!(self.out, ".Abort();")?;
2460                }
2461            },
2462            Statement::SubgroupBallot { result, predicate } => {
2463                write!(self.out, "{level}")?;
2464                let name = Baked(result).to_string();
2465                write!(self.out, "const uint4 {name} = ")?;
2466                self.named_expressions.insert(result, name);
2467
2468                write!(self.out, "WaveActiveBallot(")?;
2469                match predicate {
2470                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2471                    None => write!(self.out, "true")?,
2472                }
2473                writeln!(self.out, ");")?;
2474            }
2475            Statement::SubgroupCollectiveOperation {
2476                op,
2477                collective_op,
2478                argument,
2479                result,
2480            } => {
2481                write!(self.out, "{level}")?;
2482                write!(self.out, "const ")?;
2483                let name = Baked(result).to_string();
2484                match func_ctx.info[result].ty {
2485                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2486                    proc::TypeResolution::Value(ref value) => {
2487                        self.write_value_type(module, value)?
2488                    }
2489                };
2490                write!(self.out, " {name} = ")?;
2491                self.named_expressions.insert(result, name);
2492
2493                match (collective_op, op) {
2494                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2495                        write!(self.out, "WaveActiveAllTrue(")?
2496                    }
2497                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2498                        write!(self.out, "WaveActiveAnyTrue(")?
2499                    }
2500                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2501                        write!(self.out, "WaveActiveSum(")?
2502                    }
2503                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2504                        write!(self.out, "WaveActiveProduct(")?
2505                    }
2506                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2507                        write!(self.out, "WaveActiveMax(")?
2508                    }
2509                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2510                        write!(self.out, "WaveActiveMin(")?
2511                    }
2512                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2513                        write!(self.out, "WaveActiveBitAnd(")?
2514                    }
2515                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2516                        write!(self.out, "WaveActiveBitOr(")?
2517                    }
2518                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2519                        write!(self.out, "WaveActiveBitXor(")?
2520                    }
2521                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2522                        write!(self.out, "WavePrefixSum(")?
2523                    }
2524                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2525                        write!(self.out, "WavePrefixProduct(")?
2526                    }
2527                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2528                        self.write_expr(module, argument, func_ctx)?;
2529                        write!(self.out, " + WavePrefixSum(")?;
2530                    }
2531                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2532                        self.write_expr(module, argument, func_ctx)?;
2533                        write!(self.out, " * WavePrefixProduct(")?;
2534                    }
2535                    _ => unimplemented!(),
2536                }
2537                self.write_expr(module, argument, func_ctx)?;
2538                writeln!(self.out, ");")?;
2539            }
2540            Statement::SubgroupGather {
2541                mode,
2542                argument,
2543                result,
2544            } => {
2545                write!(self.out, "{level}")?;
2546                write!(self.out, "const ")?;
2547                let name = Baked(result).to_string();
2548                match func_ctx.info[result].ty {
2549                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2550                    proc::TypeResolution::Value(ref value) => {
2551                        self.write_value_type(module, value)?
2552                    }
2553                };
2554                write!(self.out, " {name} = ")?;
2555                self.named_expressions.insert(result, name);
2556                match mode {
2557                    crate::GatherMode::BroadcastFirst => {
2558                        write!(self.out, "WaveReadLaneFirst(")?;
2559                        self.write_expr(module, argument, func_ctx)?;
2560                    }
2561                    crate::GatherMode::QuadBroadcast(index) => {
2562                        write!(self.out, "QuadReadLaneAt(")?;
2563                        self.write_expr(module, argument, func_ctx)?;
2564                        write!(self.out, ", ")?;
2565                        self.write_expr(module, index, func_ctx)?;
2566                    }
2567                    crate::GatherMode::QuadSwap(direction) => {
2568                        match direction {
2569                            crate::Direction::X => {
2570                                write!(self.out, "QuadReadAcrossX(")?;
2571                            }
2572                            crate::Direction::Y => {
2573                                write!(self.out, "QuadReadAcrossY(")?;
2574                            }
2575                            crate::Direction::Diagonal => {
2576                                write!(self.out, "QuadReadAcrossDiagonal(")?;
2577                            }
2578                        }
2579                        self.write_expr(module, argument, func_ctx)?;
2580                    }
2581                    _ => {
2582                        write!(self.out, "WaveReadLaneAt(")?;
2583                        self.write_expr(module, argument, func_ctx)?;
2584                        write!(self.out, ", ")?;
2585                        match mode {
2586                            crate::GatherMode::BroadcastFirst => unreachable!(),
2587                            crate::GatherMode::Broadcast(index)
2588                            | crate::GatherMode::Shuffle(index) => {
2589                                self.write_expr(module, index, func_ctx)?;
2590                            }
2591                            crate::GatherMode::ShuffleDown(index) => {
2592                                write!(self.out, "WaveGetLaneIndex() + ")?;
2593                                self.write_expr(module, index, func_ctx)?;
2594                            }
2595                            crate::GatherMode::ShuffleUp(index) => {
2596                                write!(self.out, "WaveGetLaneIndex() - ")?;
2597                                self.write_expr(module, index, func_ctx)?;
2598                            }
2599                            crate::GatherMode::ShuffleXor(index) => {
2600                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
2601                                self.write_expr(module, index, func_ctx)?;
2602                            }
2603                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2604                            crate::GatherMode::QuadSwap(_) => unreachable!(),
2605                        }
2606                    }
2607                }
2608                writeln!(self.out, ");")?;
2609            }
2610        }
2611
2612        Ok(())
2613    }
2614
2615    fn write_const_expression(
2616        &mut self,
2617        module: &Module,
2618        expr: Handle<crate::Expression>,
2619        arena: &crate::Arena<crate::Expression>,
2620    ) -> BackendResult {
2621        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2622            writer.write_const_expression(module, expr, arena)
2623        })
2624    }
2625
2626    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2627        match literal {
2628            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2629            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2630            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2631            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2632            // `-2147483648` is parsed by some compilers as unary negation of
2633            // positive 2147483648, which is too large for an int, causing
2634            // issues for some compilers. Neither DXC nor FXC appear to have
2635            // this problem, but this is not specified and could change. We
2636            // therefore use `-2147483647 - 1` as a precaution.
2637            crate::Literal::I32(value) if value == i32::MIN => {
2638                write!(self.out, "int({} - 1)", value + 1)?
2639            }
2640            // HLSL has no suffix for explicit i32 literals, but not using any suffix
2641            // makes the type ambiguous which prevents overload resolution from
2642            // working. So we explicitly use the int() constructor syntax.
2643            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2644            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2645            // I64 version of the minimum I32 value issue described above.
2646            crate::Literal::I64(value) if value == i64::MIN => {
2647                write!(self.out, "({}L - 1L)", value + 1)?;
2648            }
2649            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2650            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2651            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2652                return Err(Error::Custom(
2653                    "Abstract types should not appear in IR presented to backends".into(),
2654                ));
2655            }
2656        }
2657        Ok(())
2658    }
2659
2660    fn write_possibly_const_expression<E>(
2661        &mut self,
2662        module: &Module,
2663        expr: Handle<crate::Expression>,
2664        expressions: &crate::Arena<crate::Expression>,
2665        write_expression: E,
2666    ) -> BackendResult
2667    where
2668        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2669    {
2670        use crate::Expression;
2671
2672        match expressions[expr] {
2673            Expression::Literal(literal) => {
2674                self.write_literal(literal)?;
2675            }
2676            Expression::Constant(handle) => {
2677                let constant = &module.constants[handle];
2678                if constant.name.is_some() {
2679                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2680                } else {
2681                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
2682                }
2683            }
2684            Expression::ZeroValue(ty) => {
2685                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2686                write!(self.out, "()")?;
2687            }
2688            Expression::Compose { ty, ref components } => {
2689                match module.types[ty].inner {
2690                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2691                        self.write_wrapped_constructor_function_name(
2692                            module,
2693                            WrappedConstructor { ty },
2694                        )?;
2695                    }
2696                    _ => {
2697                        self.write_type(module, ty)?;
2698                    }
2699                };
2700                write!(self.out, "(")?;
2701                for (index, component) in components.iter().enumerate() {
2702                    if index != 0 {
2703                        write!(self.out, ", ")?;
2704                    }
2705                    write_expression(self, *component)?;
2706                }
2707                write!(self.out, ")")?;
2708            }
2709            Expression::Splat { size, value } => {
2710                // hlsl is not supported one value constructor
2711                // if we write, for example, int4(0), dxc returns error:
2712                // error: too few elements in vector initialization (expected 4 elements, have 1)
2713                let number_of_components = match size {
2714                    crate::VectorSize::Bi => "xx",
2715                    crate::VectorSize::Tri => "xxx",
2716                    crate::VectorSize::Quad => "xxxx",
2717                };
2718                write!(self.out, "(")?;
2719                write_expression(self, value)?;
2720                write!(self.out, ").{number_of_components}")?
2721            }
2722            _ => {
2723                return Err(Error::Override);
2724            }
2725        }
2726
2727        Ok(())
2728    }
2729
2730    /// Helper method to write expressions
2731    ///
2732    /// # Notes
2733    /// Doesn't add any newlines or leading/trailing spaces
2734    pub(super) fn write_expr(
2735        &mut self,
2736        module: &Module,
2737        expr: Handle<crate::Expression>,
2738        func_ctx: &back::FunctionCtx<'_>,
2739    ) -> BackendResult {
2740        use crate::Expression;
2741
2742        // Handle the special semantics of vertex_index/instance_index
2743        let ff_input = if self.options.special_constants_binding.is_some() {
2744            func_ctx.is_fixed_function_input(expr, module)
2745        } else {
2746            None
2747        };
2748        let closing_bracket = match ff_input {
2749            Some(crate::BuiltIn::VertexIndex) => {
2750                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2751                ")"
2752            }
2753            Some(crate::BuiltIn::InstanceIndex) => {
2754                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2755                ")"
2756            }
2757            Some(crate::BuiltIn::NumWorkGroups) => {
2758                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
2759                // in compute shaders the special constants contain the number
2760                // of workgroups, which we are using here.
2761                write!(
2762                    self.out,
2763                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2764                )?;
2765                return Ok(());
2766            }
2767            _ => "",
2768        };
2769
2770        if let Some(name) = self.named_expressions.get(&expr) {
2771            write!(self.out, "{name}{closing_bracket}")?;
2772            return Ok(());
2773        }
2774
2775        let expression = &func_ctx.expressions[expr];
2776
2777        match *expression {
2778            Expression::Literal(_)
2779            | Expression::Constant(_)
2780            | Expression::ZeroValue(_)
2781            | Expression::Compose { .. }
2782            | Expression::Splat { .. } => {
2783                self.write_possibly_const_expression(
2784                    module,
2785                    expr,
2786                    func_ctx.expressions,
2787                    |writer, expr| writer.write_expr(module, expr, func_ctx),
2788                )?;
2789            }
2790            Expression::Override(_) => return Err(Error::Override),
2791            // Avoid undefined behaviour for addition, subtraction, and
2792            // multiplication of signed integers by casting operands to
2793            // unsigned, performing the operation, then casting the result back
2794            // to signed.
2795            // TODO(#7109): This relies on the asint()/asuint() functions which only work
2796            // for 32-bit types, so we must find another solution for different bit widths.
2797            Expression::Binary {
2798                op:
2799                    op @ crate::BinaryOperator::Add
2800                    | op @ crate::BinaryOperator::Subtract
2801                    | op @ crate::BinaryOperator::Multiply,
2802                left,
2803                right,
2804            } if matches!(
2805                func_ctx.resolve_type(expr, &module.types).scalar(),
2806                Some(Scalar::I32)
2807            ) =>
2808            {
2809                write!(self.out, "asint(asuint(",)?;
2810                self.write_expr(module, left, func_ctx)?;
2811                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
2812                self.write_expr(module, right, func_ctx)?;
2813                write!(self.out, "))")?;
2814            }
2815            // All of the multiplication can be expressed as `mul`,
2816            // except vector * vector, which needs to use the "*" operator.
2817            Expression::Binary {
2818                op: crate::BinaryOperator::Multiply,
2819                left,
2820                right,
2821            } if func_ctx.resolve_type(left, &module.types).is_matrix()
2822                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2823            {
2824                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
2825                write!(self.out, "mul(")?;
2826                self.write_expr(module, right, func_ctx)?;
2827                write!(self.out, ", ")?;
2828                self.write_expr(module, left, func_ctx)?;
2829                write!(self.out, ")")?;
2830            }
2831
2832            // WGSL says that floating-point division by zero should return
2833            // infinity. Microsoft's Direct3D 11 functional specification
2834            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
2835            // says:
2836            //
2837            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
2838            //
2839            // which is what we want. The DXIL specification for the FDiv
2840            // instruction corroborates this:
2841            //
2842            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
2843            Expression::Binary {
2844                op: crate::BinaryOperator::Divide,
2845                left,
2846                right,
2847            } if matches!(
2848                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2849                Some(ScalarKind::Sint | ScalarKind::Uint)
2850            ) =>
2851            {
2852                write!(self.out, "{DIV_FUNCTION}(")?;
2853                self.write_expr(module, left, func_ctx)?;
2854                write!(self.out, ", ")?;
2855                self.write_expr(module, right, func_ctx)?;
2856                write!(self.out, ")")?;
2857            }
2858
2859            Expression::Binary {
2860                op: crate::BinaryOperator::Modulo,
2861                left,
2862                right,
2863            } if matches!(
2864                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2865                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
2866            ) =>
2867            {
2868                write!(self.out, "{MOD_FUNCTION}(")?;
2869                self.write_expr(module, left, func_ctx)?;
2870                write!(self.out, ", ")?;
2871                self.write_expr(module, right, func_ctx)?;
2872                write!(self.out, ")")?;
2873            }
2874
2875            Expression::Binary { op, left, right } => {
2876                write!(self.out, "(")?;
2877                self.write_expr(module, left, func_ctx)?;
2878                write!(self.out, " {} ", back::binary_operation_str(op))?;
2879                self.write_expr(module, right, func_ctx)?;
2880                write!(self.out, ")")?;
2881            }
2882            Expression::Access { base, index } => {
2883                if let Some(crate::AddressSpace::Storage { .. }) =
2884                    func_ctx.resolve_type(expr, &module.types).pointer_space()
2885                {
2886                    // do nothing, the chain is written on `Load`/`Store`
2887                } else {
2888                    // We use the function __get_col_of_matCx2 here in cases
2889                    // where `base`s type resolves to a matCx2 and is part of a
2890                    // struct member with type of (possibly nested) array of matCx2's.
2891                    //
2892                    // Note that this only works for `Load`s and we handle
2893                    // `Store`s differently in `Statement::Store`.
2894                    if let Some(MatrixType {
2895                        columns,
2896                        rows: crate::VectorSize::Bi,
2897                        width: 4,
2898                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
2899                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
2900                    {
2901                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
2902                        self.write_expr(module, base, func_ctx)?;
2903                        write!(self.out, ", ")?;
2904                        self.write_expr(module, index, func_ctx)?;
2905                        write!(self.out, ")")?;
2906                        return Ok(());
2907                    }
2908
2909                    let resolved = func_ctx.resolve_type(base, &module.types);
2910
2911                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
2912                        TypeInner::BindingArray { .. } => {
2913                            let uniformity = &func_ctx.info[index].uniformity;
2914
2915                            (true, uniformity.non_uniform_result.is_some())
2916                        }
2917                        _ => (false, false),
2918                    };
2919
2920                    self.write_expr(module, base, func_ctx)?;
2921
2922                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
2923                        module, func_ctx, base, resolved,
2924                    );
2925
2926                    if let Some(ref info) = array_sampler_info {
2927                        write!(self.out, "{}[", info.sampler_heap_name)?;
2928                    } else {
2929                        write!(self.out, "[")?;
2930                    }
2931
2932                    let needs_bound_check = self.options.restrict_indexing
2933                        && !indexing_binding_array
2934                        && match resolved.pointer_space() {
2935                            Some(
2936                                crate::AddressSpace::Function
2937                                | crate::AddressSpace::Private
2938                                | crate::AddressSpace::WorkGroup
2939                                | crate::AddressSpace::PushConstant,
2940                            )
2941                            | None => true,
2942                            Some(crate::AddressSpace::Uniform) => {
2943                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
2944                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
2945                                let bind_target = self
2946                                    .options
2947                                    .resolve_resource_binding(
2948                                        module.global_variables[var_handle]
2949                                            .binding
2950                                            .as_ref()
2951                                            .unwrap(),
2952                                    )
2953                                    .unwrap();
2954                                bind_target.restrict_indexing
2955                            }
2956                            Some(
2957                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
2958                            ) => unreachable!(),
2959                        };
2960                    // Decide whether this index needs to be clamped to fall within range.
2961                    let restriction_needed = if needs_bound_check {
2962                        index::access_needs_check(
2963                            base,
2964                            index::GuardedIndex::Expression(index),
2965                            module,
2966                            func_ctx.expressions,
2967                            func_ctx.info,
2968                        )
2969                    } else {
2970                        None
2971                    };
2972                    if let Some(limit) = restriction_needed {
2973                        write!(self.out, "min(uint(")?;
2974                        self.write_expr(module, index, func_ctx)?;
2975                        write!(self.out, "), ")?;
2976                        match limit {
2977                            index::IndexableLength::Known(limit) => {
2978                                write!(self.out, "{}u", limit - 1)?;
2979                            }
2980                            index::IndexableLength::Dynamic => unreachable!(),
2981                        }
2982                        write!(self.out, ")")?;
2983                    } else {
2984                        if non_uniform_qualifier {
2985                            write!(self.out, "NonUniformResourceIndex(")?;
2986                        }
2987                        if let Some(ref info) = array_sampler_info {
2988                            write!(
2989                                self.out,
2990                                "{}[{} + ",
2991                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
2992                            )?;
2993                        }
2994                        self.write_expr(module, index, func_ctx)?;
2995                        if array_sampler_info.is_some() {
2996                            write!(self.out, "]")?;
2997                        }
2998                        if non_uniform_qualifier {
2999                            write!(self.out, ")")?;
3000                        }
3001                    }
3002
3003                    write!(self.out, "]")?;
3004                }
3005            }
3006            Expression::AccessIndex { base, index } => {
3007                if let Some(crate::AddressSpace::Storage { .. }) =
3008                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3009                {
3010                    // do nothing, the chain is written on `Load`/`Store`
3011                } else {
3012                    // See if we need to write the matrix column access in a
3013                    // special way since the type of `base` is our special
3014                    // __matCx2 struct.
3015                    if let Some(MatrixType {
3016                        rows: crate::VectorSize::Bi,
3017                        width: 4,
3018                        ..
3019                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3020                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3021                    {
3022                        self.write_expr(module, base, func_ctx)?;
3023                        write!(self.out, "._{index}")?;
3024                        return Ok(());
3025                    }
3026
3027                    let base_ty_res = &func_ctx.info[base].ty;
3028                    let mut resolved = base_ty_res.inner_with(&module.types);
3029                    let base_ty_handle = match *resolved {
3030                        TypeInner::Pointer { base, .. } => {
3031                            resolved = &module.types[base].inner;
3032                            Some(base)
3033                        }
3034                        _ => base_ty_res.handle(),
3035                    };
3036
3037                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3038                    // See the module-level block comment in mod.rs for details.
3039                    //
3040                    // We handle matrix reconstruction here for Loads.
3041                    // Stores are handled directly by `Statement::Store`.
3042                    if let TypeInner::Struct { ref members, .. } = *resolved {
3043                        let member = &members[index as usize];
3044
3045                        match module.types[member.ty].inner {
3046                            TypeInner::Matrix {
3047                                rows: crate::VectorSize::Bi,
3048                                ..
3049                            } if member.binding.is_none() => {
3050                                let ty = base_ty_handle.unwrap();
3051                                self.write_wrapped_struct_matrix_get_function_name(
3052                                    WrappedStructMatrixAccess { ty, index },
3053                                )?;
3054                                write!(self.out, "(")?;
3055                                self.write_expr(module, base, func_ctx)?;
3056                                write!(self.out, ")")?;
3057                                return Ok(());
3058                            }
3059                            _ => {}
3060                        }
3061                    }
3062
3063                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3064                        module, func_ctx, base, resolved,
3065                    );
3066
3067                    if let Some(ref info) = array_sampler_info {
3068                        write!(
3069                            self.out,
3070                            "{}[{}",
3071                            info.sampler_heap_name, info.sampler_index_buffer_name
3072                        )?;
3073                    }
3074
3075                    self.write_expr(module, base, func_ctx)?;
3076
3077                    match *resolved {
3078                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3079                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3080                        // it and generates completely nonsensical DXBC.
3081                        //
3082                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3083                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3084                            // Write vector access as a swizzle
3085                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3086                        }
3087                        TypeInner::Matrix { .. }
3088                        | TypeInner::Array { .. }
3089                        | TypeInner::BindingArray { .. } => {
3090                            if let Some(ref info) = array_sampler_info {
3091                                write!(
3092                                    self.out,
3093                                    "[{} + {index}]",
3094                                    info.binding_array_base_index_name
3095                                )?;
3096                            } else {
3097                                write!(self.out, "[{index}]")?;
3098                            }
3099                        }
3100                        TypeInner::Struct { .. } => {
3101                            // This will never panic in case the type is a `Struct`, this is not true
3102                            // for other types so we can only check while inside this match arm
3103                            let ty = base_ty_handle.unwrap();
3104
3105                            write!(
3106                                self.out,
3107                                ".{}",
3108                                &self.names[&NameKey::StructMember(ty, index)]
3109                            )?
3110                        }
3111                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3112                    }
3113
3114                    if array_sampler_info.is_some() {
3115                        write!(self.out, "]")?;
3116                    }
3117                }
3118            }
3119            Expression::FunctionArgument(pos) => {
3120                let key = func_ctx.argument_key(pos);
3121                let name = &self.names[&key];
3122                write!(self.out, "{name}")?;
3123            }
3124            Expression::ImageSample {
3125                coordinate,
3126                image,
3127                sampler,
3128                clamp_to_edge: true,
3129                gather: None,
3130                array_index: None,
3131                offset: None,
3132                level: crate::SampleLevel::Zero,
3133                depth_ref: None,
3134            } => {
3135                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3136                self.write_expr(module, image, func_ctx)?;
3137                write!(self.out, ", ")?;
3138                self.write_expr(module, sampler, func_ctx)?;
3139                write!(self.out, ", ")?;
3140                self.write_expr(module, coordinate, func_ctx)?;
3141                write!(self.out, ")")?;
3142            }
3143            Expression::ImageSample {
3144                image,
3145                sampler,
3146                gather,
3147                coordinate,
3148                array_index,
3149                offset,
3150                level,
3151                depth_ref,
3152                clamp_to_edge,
3153            } => {
3154                if clamp_to_edge {
3155                    return Err(Error::Custom(
3156                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3157                    ));
3158                }
3159
3160                use crate::SampleLevel as Sl;
3161                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3162
3163                let (base_str, component_str) = match gather {
3164                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3165                    None => ("Sample", ""),
3166                };
3167                let cmp_str = match depth_ref {
3168                    Some(_) => "Cmp",
3169                    None => "",
3170                };
3171                let level_str = match level {
3172                    Sl::Zero if gather.is_none() => "LevelZero",
3173                    Sl::Auto | Sl::Zero => "",
3174                    Sl::Exact(_) => "Level",
3175                    Sl::Bias(_) => "Bias",
3176                    Sl::Gradient { .. } => "Grad",
3177                };
3178
3179                self.write_expr(module, image, func_ctx)?;
3180                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3181                self.write_expr(module, sampler, func_ctx)?;
3182                write!(self.out, ", ")?;
3183                self.write_texture_coordinates(
3184                    "float",
3185                    coordinate,
3186                    array_index,
3187                    None,
3188                    module,
3189                    func_ctx,
3190                )?;
3191
3192                if let Some(depth_ref) = depth_ref {
3193                    write!(self.out, ", ")?;
3194                    self.write_expr(module, depth_ref, func_ctx)?;
3195                }
3196
3197                match level {
3198                    Sl::Auto | Sl::Zero => {}
3199                    Sl::Exact(expr) => {
3200                        write!(self.out, ", ")?;
3201                        self.write_expr(module, expr, func_ctx)?;
3202                    }
3203                    Sl::Bias(expr) => {
3204                        write!(self.out, ", ")?;
3205                        self.write_expr(module, expr, func_ctx)?;
3206                    }
3207                    Sl::Gradient { x, y } => {
3208                        write!(self.out, ", ")?;
3209                        self.write_expr(module, x, func_ctx)?;
3210                        write!(self.out, ", ")?;
3211                        self.write_expr(module, y, func_ctx)?;
3212                    }
3213                }
3214
3215                if let Some(offset) = offset {
3216                    write!(self.out, ", ")?;
3217                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3218                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3219                    write!(self.out, ")")?;
3220                }
3221
3222                write!(self.out, ")")?;
3223            }
3224            Expression::ImageQuery { image, query } => {
3225                // use wrapped image query function
3226                if let TypeInner::Image {
3227                    dim,
3228                    arrayed,
3229                    class,
3230                } = *func_ctx.resolve_type(image, &module.types)
3231                {
3232                    let wrapped_image_query = WrappedImageQuery {
3233                        dim,
3234                        arrayed,
3235                        class,
3236                        query: query.into(),
3237                    };
3238
3239                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3240                    write!(self.out, "(")?;
3241                    // Image always first param
3242                    self.write_expr(module, image, func_ctx)?;
3243                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3244                        write!(self.out, ", ")?;
3245                        self.write_expr(module, level, func_ctx)?;
3246                    }
3247                    write!(self.out, ")")?;
3248                }
3249            }
3250            Expression::ImageLoad {
3251                image,
3252                coordinate,
3253                array_index,
3254                sample,
3255                level,
3256            } => self.write_image_load(
3257                &module,
3258                expr,
3259                func_ctx,
3260                image,
3261                coordinate,
3262                array_index,
3263                sample,
3264                level,
3265            )?,
3266            Expression::GlobalVariable(handle) => {
3267                let global_variable = &module.global_variables[handle];
3268                let ty = &module.types[global_variable.ty].inner;
3269
3270                // In the case of binding arrays of samplers, we need to not write anything
3271                // as we are in the wrong position to fully write the expression.
3272                //
3273                // The entire writing is done by AccessIndex.
3274                let is_binding_array_of_samplers = match *ty {
3275                    TypeInner::BindingArray { base, .. } => {
3276                        let base_ty = &module.types[base].inner;
3277                        matches!(*base_ty, TypeInner::Sampler { .. })
3278                    }
3279                    _ => false,
3280                };
3281
3282                let is_storage_space =
3283                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3284
3285                if !is_binding_array_of_samplers && !is_storage_space {
3286                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3287                    write!(self.out, "{name}")?;
3288                }
3289            }
3290            Expression::LocalVariable(handle) => {
3291                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3292            }
3293            Expression::Load { pointer } => {
3294                match func_ctx
3295                    .resolve_type(pointer, &module.types)
3296                    .pointer_space()
3297                {
3298                    Some(crate::AddressSpace::Storage { .. }) => {
3299                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3300                        let result_ty = func_ctx.info[expr].ty.clone();
3301                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3302                    }
3303                    _ => {
3304                        let mut close_paren = false;
3305
3306                        // We cast the value loaded to a native HLSL floatCx2
3307                        // in cases where it is of type:
3308                        //  - __matCx2 or
3309                        //  - a (possibly nested) array of __matCx2's
3310                        if let Some(MatrixType {
3311                            rows: crate::VectorSize::Bi,
3312                            width: 4,
3313                            ..
3314                        }) = get_inner_matrix_of_struct_array_member(
3315                            module, pointer, func_ctx, false,
3316                        )
3317                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3318                        {
3319                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3320                            let ptr_tr = resolved.pointer_base_type();
3321                            if let Some(ptr_ty) =
3322                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3323                            {
3324                                resolved = ptr_ty;
3325                            }
3326
3327                            write!(self.out, "((")?;
3328                            if let TypeInner::Array { base, size, .. } = *resolved {
3329                                self.write_type(module, base)?;
3330                                self.write_array_size(module, base, size)?;
3331                            } else {
3332                                self.write_value_type(module, resolved)?;
3333                            }
3334                            write!(self.out, ")")?;
3335                            close_paren = true;
3336                        }
3337
3338                        self.write_expr(module, pointer, func_ctx)?;
3339
3340                        if close_paren {
3341                            write!(self.out, ")")?;
3342                        }
3343                    }
3344                }
3345            }
3346            Expression::Unary { op, expr } => {
3347                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3348                let op_str = match op {
3349                    crate::UnaryOperator::Negate => {
3350                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3351                            Some(Scalar::I32) => NEG_FUNCTION,
3352                            _ => "-",
3353                        }
3354                    }
3355                    crate::UnaryOperator::LogicalNot => "!",
3356                    crate::UnaryOperator::BitwiseNot => "~",
3357                };
3358                write!(self.out, "{op_str}(")?;
3359                self.write_expr(module, expr, func_ctx)?;
3360                write!(self.out, ")")?;
3361            }
3362            Expression::As {
3363                expr,
3364                kind,
3365                convert,
3366            } => {
3367                let inner = func_ctx.resolve_type(expr, &module.types);
3368                if inner.scalar_kind() == Some(ScalarKind::Float)
3369                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3370                    && convert.is_some()
3371                {
3372                    // Use helper functions for float to int casts in order to
3373                    // avoid undefined behaviour when value is out of range for
3374                    // the target type.
3375                    let fun_name = match (kind, convert) {
3376                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3377                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3378                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3379                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3380                        _ => unreachable!(),
3381                    };
3382                    write!(self.out, "{fun_name}(")?;
3383                    self.write_expr(module, expr, func_ctx)?;
3384                    write!(self.out, ")")?;
3385                } else {
3386                    let close_paren = match convert {
3387                        Some(dst_width) => {
3388                            let scalar = Scalar {
3389                                kind,
3390                                width: dst_width,
3391                            };
3392                            match *inner {
3393                                TypeInner::Vector { size, .. } => {
3394                                    write!(
3395                                        self.out,
3396                                        "{}{}(",
3397                                        scalar.to_hlsl_str()?,
3398                                        common::vector_size_str(size)
3399                                    )?;
3400                                }
3401                                TypeInner::Scalar(_) => {
3402                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3403                                }
3404                                TypeInner::Matrix { columns, rows, .. } => {
3405                                    write!(
3406                                        self.out,
3407                                        "{}{}x{}(",
3408                                        scalar.to_hlsl_str()?,
3409                                        common::vector_size_str(columns),
3410                                        common::vector_size_str(rows)
3411                                    )?;
3412                                }
3413                                _ => {
3414                                    return Err(Error::Unimplemented(format!(
3415                                        "write_expr expression::as {inner:?}"
3416                                    )));
3417                                }
3418                            };
3419                            true
3420                        }
3421                        None => {
3422                            if inner.scalar_width() == Some(8) {
3423                                false
3424                            } else {
3425                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3426                                true
3427                            }
3428                        }
3429                    };
3430                    self.write_expr(module, expr, func_ctx)?;
3431                    if close_paren {
3432                        write!(self.out, ")")?;
3433                    }
3434                }
3435            }
3436            Expression::Math {
3437                fun,
3438                arg,
3439                arg1,
3440                arg2,
3441                arg3,
3442            } => {
3443                use crate::MathFunction as Mf;
3444
3445                enum Function {
3446                    Asincosh { is_sin: bool },
3447                    Atanh,
3448                    Pack2x16float,
3449                    Pack2x16snorm,
3450                    Pack2x16unorm,
3451                    Pack4x8snorm,
3452                    Pack4x8unorm,
3453                    Pack4xI8,
3454                    Pack4xU8,
3455                    Pack4xI8Clamp,
3456                    Pack4xU8Clamp,
3457                    Unpack2x16float,
3458                    Unpack2x16snorm,
3459                    Unpack2x16unorm,
3460                    Unpack4x8snorm,
3461                    Unpack4x8unorm,
3462                    Unpack4xI8,
3463                    Unpack4xU8,
3464                    Dot4I8Packed,
3465                    Dot4U8Packed,
3466                    QuantizeToF16,
3467                    Regular(&'static str),
3468                    MissingIntOverload(&'static str),
3469                    MissingIntReturnType(&'static str),
3470                    CountTrailingZeros,
3471                    CountLeadingZeros,
3472                }
3473
3474                let fun = match fun {
3475                    // comparison
3476                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3477                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3478                        _ => Function::Regular("abs"),
3479                    },
3480                    Mf::Min => Function::Regular("min"),
3481                    Mf::Max => Function::Regular("max"),
3482                    Mf::Clamp => Function::Regular("clamp"),
3483                    Mf::Saturate => Function::Regular("saturate"),
3484                    // trigonometry
3485                    Mf::Cos => Function::Regular("cos"),
3486                    Mf::Cosh => Function::Regular("cosh"),
3487                    Mf::Sin => Function::Regular("sin"),
3488                    Mf::Sinh => Function::Regular("sinh"),
3489                    Mf::Tan => Function::Regular("tan"),
3490                    Mf::Tanh => Function::Regular("tanh"),
3491                    Mf::Acos => Function::Regular("acos"),
3492                    Mf::Asin => Function::Regular("asin"),
3493                    Mf::Atan => Function::Regular("atan"),
3494                    Mf::Atan2 => Function::Regular("atan2"),
3495                    Mf::Asinh => Function::Asincosh { is_sin: true },
3496                    Mf::Acosh => Function::Asincosh { is_sin: false },
3497                    Mf::Atanh => Function::Atanh,
3498                    Mf::Radians => Function::Regular("radians"),
3499                    Mf::Degrees => Function::Regular("degrees"),
3500                    // decomposition
3501                    Mf::Ceil => Function::Regular("ceil"),
3502                    Mf::Floor => Function::Regular("floor"),
3503                    Mf::Round => Function::Regular("round"),
3504                    Mf::Fract => Function::Regular("frac"),
3505                    Mf::Trunc => Function::Regular("trunc"),
3506                    Mf::Modf => Function::Regular(MODF_FUNCTION),
3507                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3508                    Mf::Ldexp => Function::Regular("ldexp"),
3509                    // exponent
3510                    Mf::Exp => Function::Regular("exp"),
3511                    Mf::Exp2 => Function::Regular("exp2"),
3512                    Mf::Log => Function::Regular("log"),
3513                    Mf::Log2 => Function::Regular("log2"),
3514                    Mf::Pow => Function::Regular("pow"),
3515                    // geometry
3516                    Mf::Dot => Function::Regular("dot"),
3517                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
3518                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
3519                    //Mf::Outer => ,
3520                    Mf::Cross => Function::Regular("cross"),
3521                    Mf::Distance => Function::Regular("distance"),
3522                    Mf::Length => Function::Regular("length"),
3523                    Mf::Normalize => Function::Regular("normalize"),
3524                    Mf::FaceForward => Function::Regular("faceforward"),
3525                    Mf::Reflect => Function::Regular("reflect"),
3526                    Mf::Refract => Function::Regular("refract"),
3527                    // computational
3528                    Mf::Sign => Function::Regular("sign"),
3529                    Mf::Fma => Function::Regular("mad"),
3530                    Mf::Mix => Function::Regular("lerp"),
3531                    Mf::Step => Function::Regular("step"),
3532                    Mf::SmoothStep => Function::Regular("smoothstep"),
3533                    Mf::Sqrt => Function::Regular("sqrt"),
3534                    Mf::InverseSqrt => Function::Regular("rsqrt"),
3535                    //Mf::Inverse =>,
3536                    Mf::Transpose => Function::Regular("transpose"),
3537                    Mf::Determinant => Function::Regular("determinant"),
3538                    Mf::QuantizeToF16 => Function::QuantizeToF16,
3539                    // bits
3540                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
3541                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
3542                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3543                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3544                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3545                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3546                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3547                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3548                    // Data Packing
3549                    Mf::Pack2x16float => Function::Pack2x16float,
3550                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
3551                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
3552                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
3553                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
3554                    Mf::Pack4xI8 => Function::Pack4xI8,
3555                    Mf::Pack4xU8 => Function::Pack4xU8,
3556                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3557                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3558                    // Data Unpacking
3559                    Mf::Unpack2x16float => Function::Unpack2x16float,
3560                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3561                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3562                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3563                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3564                    Mf::Unpack4xI8 => Function::Unpack4xI8,
3565                    Mf::Unpack4xU8 => Function::Unpack4xU8,
3566                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3567                };
3568
3569                match fun {
3570                    Function::Asincosh { is_sin } => {
3571                        write!(self.out, "log(")?;
3572                        self.write_expr(module, arg, func_ctx)?;
3573                        write!(self.out, " + sqrt(")?;
3574                        self.write_expr(module, arg, func_ctx)?;
3575                        write!(self.out, " * ")?;
3576                        self.write_expr(module, arg, func_ctx)?;
3577                        match is_sin {
3578                            true => write!(self.out, " + 1.0))")?,
3579                            false => write!(self.out, " - 1.0))")?,
3580                        }
3581                    }
3582                    Function::Atanh => {
3583                        write!(self.out, "0.5 * log((1.0 + ")?;
3584                        self.write_expr(module, arg, func_ctx)?;
3585                        write!(self.out, ") / (1.0 - ")?;
3586                        self.write_expr(module, arg, func_ctx)?;
3587                        write!(self.out, "))")?;
3588                    }
3589                    Function::Pack2x16float => {
3590                        write!(self.out, "(f32tof16(")?;
3591                        self.write_expr(module, arg, func_ctx)?;
3592                        write!(self.out, "[0]) | f32tof16(")?;
3593                        self.write_expr(module, arg, func_ctx)?;
3594                        write!(self.out, "[1]) << 16)")?;
3595                    }
3596                    Function::Pack2x16snorm => {
3597                        let scale = 32767;
3598
3599                        write!(self.out, "uint((int(round(clamp(")?;
3600                        self.write_expr(module, arg, func_ctx)?;
3601                        write!(
3602                            self.out,
3603                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3604                        )?;
3605                        self.write_expr(module, arg, func_ctx)?;
3606                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3607                    }
3608                    Function::Pack2x16unorm => {
3609                        let scale = 65535;
3610
3611                        write!(self.out, "(uint(round(clamp(")?;
3612                        self.write_expr(module, arg, func_ctx)?;
3613                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3614                        self.write_expr(module, arg, func_ctx)?;
3615                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3616                    }
3617                    Function::Pack4x8snorm => {
3618                        let scale = 127;
3619
3620                        write!(self.out, "uint((int(round(clamp(")?;
3621                        self.write_expr(module, arg, func_ctx)?;
3622                        write!(
3623                            self.out,
3624                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3625                        )?;
3626                        self.write_expr(module, arg, func_ctx)?;
3627                        write!(
3628                            self.out,
3629                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3630                        )?;
3631                        self.write_expr(module, arg, func_ctx)?;
3632                        write!(
3633                            self.out,
3634                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3635                        )?;
3636                        self.write_expr(module, arg, func_ctx)?;
3637                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3638                    }
3639                    Function::Pack4x8unorm => {
3640                        let scale = 255;
3641
3642                        write!(self.out, "(uint(round(clamp(")?;
3643                        self.write_expr(module, arg, func_ctx)?;
3644                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3645                        self.write_expr(module, arg, func_ctx)?;
3646                        write!(
3647                            self.out,
3648                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3649                        )?;
3650                        self.write_expr(module, arg, func_ctx)?;
3651                        write!(
3652                            self.out,
3653                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3654                        )?;
3655                        self.write_expr(module, arg, func_ctx)?;
3656                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3657                    }
3658                    fun @ (Function::Pack4xI8
3659                    | Function::Pack4xU8
3660                    | Function::Pack4xI8Clamp
3661                    | Function::Pack4xU8Clamp) => {
3662                        let was_signed =
3663                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3664                        let clamp_bounds = match fun {
3665                            Function::Pack4xI8Clamp => Some(("-128", "127")),
3666                            Function::Pack4xU8Clamp => Some(("0", "255")),
3667                            _ => None,
3668                        };
3669                        if was_signed {
3670                            write!(self.out, "uint(")?;
3671                        }
3672                        let write_arg = |this: &mut Self| -> BackendResult {
3673                            if let Some((min, max)) = clamp_bounds {
3674                                write!(this.out, "clamp(")?;
3675                                this.write_expr(module, arg, func_ctx)?;
3676                                write!(this.out, ", {min}, {max})")?;
3677                            } else {
3678                                this.write_expr(module, arg, func_ctx)?;
3679                            }
3680                            Ok(())
3681                        };
3682                        write!(self.out, "(")?;
3683                        write_arg(self)?;
3684                        write!(self.out, "[0] & 0xFF) | ((")?;
3685                        write_arg(self)?;
3686                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3687                        write_arg(self)?;
3688                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3689                        write_arg(self)?;
3690                        write!(self.out, "[3] & 0xFF) << 24)")?;
3691                        if was_signed {
3692                            write!(self.out, ")")?;
3693                        }
3694                    }
3695
3696                    Function::Unpack2x16float => {
3697                        write!(self.out, "float2(f16tof32(")?;
3698                        self.write_expr(module, arg, func_ctx)?;
3699                        write!(self.out, "), f16tof32((")?;
3700                        self.write_expr(module, arg, func_ctx)?;
3701                        write!(self.out, ") >> 16))")?;
3702                    }
3703                    Function::Unpack2x16snorm => {
3704                        let scale = 32767;
3705
3706                        write!(self.out, "(float2(int2(")?;
3707                        self.write_expr(module, arg, func_ctx)?;
3708                        write!(self.out, " << 16, ")?;
3709                        self.write_expr(module, arg, func_ctx)?;
3710                        write!(self.out, ") >> 16) / {scale}.0)")?;
3711                    }
3712                    Function::Unpack2x16unorm => {
3713                        let scale = 65535;
3714
3715                        write!(self.out, "(float2(")?;
3716                        self.write_expr(module, arg, func_ctx)?;
3717                        write!(self.out, " & 0xFFFF, ")?;
3718                        self.write_expr(module, arg, func_ctx)?;
3719                        write!(self.out, " >> 16) / {scale}.0)")?;
3720                    }
3721                    Function::Unpack4x8snorm => {
3722                        let scale = 127;
3723
3724                        write!(self.out, "(float4(int4(")?;
3725                        self.write_expr(module, arg, func_ctx)?;
3726                        write!(self.out, " << 24, ")?;
3727                        self.write_expr(module, arg, func_ctx)?;
3728                        write!(self.out, " << 16, ")?;
3729                        self.write_expr(module, arg, func_ctx)?;
3730                        write!(self.out, " << 8, ")?;
3731                        self.write_expr(module, arg, func_ctx)?;
3732                        write!(self.out, ") >> 24) / {scale}.0)")?;
3733                    }
3734                    Function::Unpack4x8unorm => {
3735                        let scale = 255;
3736
3737                        write!(self.out, "(float4(")?;
3738                        self.write_expr(module, arg, func_ctx)?;
3739                        write!(self.out, " & 0xFF, ")?;
3740                        self.write_expr(module, arg, func_ctx)?;
3741                        write!(self.out, " >> 8 & 0xFF, ")?;
3742                        self.write_expr(module, arg, func_ctx)?;
3743                        write!(self.out, " >> 16 & 0xFF, ")?;
3744                        self.write_expr(module, arg, func_ctx)?;
3745                        write!(self.out, " >> 24) / {scale}.0)")?;
3746                    }
3747                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3748                        write!(self.out, "(")?;
3749                        if matches!(fun, Function::Unpack4xU8) {
3750                            write!(self.out, "u")?;
3751                        }
3752                        write!(self.out, "int4(")?;
3753                        self.write_expr(module, arg, func_ctx)?;
3754                        write!(self.out, ", ")?;
3755                        self.write_expr(module, arg, func_ctx)?;
3756                        write!(self.out, " >> 8, ")?;
3757                        self.write_expr(module, arg, func_ctx)?;
3758                        write!(self.out, " >> 16, ")?;
3759                        self.write_expr(module, arg, func_ctx)?;
3760                        write!(self.out, " >> 24) << 24 >> 24)")?;
3761                    }
3762                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3763                        let arg1 = arg1.unwrap();
3764
3765                        if self.options.shader_model >= ShaderModel::V6_4 {
3766                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
3767                            let function_name = match fun {
3768                                Function::Dot4I8Packed => "dot4add_i8packed",
3769                                Function::Dot4U8Packed => "dot4add_u8packed",
3770                                _ => unreachable!(),
3771                            };
3772                            write!(self.out, "{function_name}(")?;
3773                            self.write_expr(module, arg, func_ctx)?;
3774                            write!(self.out, ", ")?;
3775                            self.write_expr(module, arg1, func_ctx)?;
3776                            write!(self.out, ", 0)")?;
3777                        } else {
3778                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
3779                            write!(self.out, "dot(")?;
3780
3781                            if matches!(fun, Function::Dot4U8Packed) {
3782                                write!(self.out, "u")?;
3783                            }
3784                            write!(self.out, "int4(")?;
3785                            self.write_expr(module, arg, func_ctx)?;
3786                            write!(self.out, ", ")?;
3787                            self.write_expr(module, arg, func_ctx)?;
3788                            write!(self.out, " >> 8, ")?;
3789                            self.write_expr(module, arg, func_ctx)?;
3790                            write!(self.out, " >> 16, ")?;
3791                            self.write_expr(module, arg, func_ctx)?;
3792                            write!(self.out, " >> 24) << 24 >> 24, ")?;
3793
3794                            if matches!(fun, Function::Dot4U8Packed) {
3795                                write!(self.out, "u")?;
3796                            }
3797                            write!(self.out, "int4(")?;
3798                            self.write_expr(module, arg1, func_ctx)?;
3799                            write!(self.out, ", ")?;
3800                            self.write_expr(module, arg1, func_ctx)?;
3801                            write!(self.out, " >> 8, ")?;
3802                            self.write_expr(module, arg1, func_ctx)?;
3803                            write!(self.out, " >> 16, ")?;
3804                            self.write_expr(module, arg1, func_ctx)?;
3805                            write!(self.out, " >> 24) << 24 >> 24)")?;
3806                        }
3807                    }
3808                    Function::QuantizeToF16 => {
3809                        write!(self.out, "f16tof32(f32tof16(")?;
3810                        self.write_expr(module, arg, func_ctx)?;
3811                        write!(self.out, "))")?;
3812                    }
3813                    Function::Regular(fun_name) => {
3814                        write!(self.out, "{fun_name}(")?;
3815                        self.write_expr(module, arg, func_ctx)?;
3816                        if let Some(arg) = arg1 {
3817                            write!(self.out, ", ")?;
3818                            self.write_expr(module, arg, func_ctx)?;
3819                        }
3820                        if let Some(arg) = arg2 {
3821                            write!(self.out, ", ")?;
3822                            self.write_expr(module, arg, func_ctx)?;
3823                        }
3824                        if let Some(arg) = arg3 {
3825                            write!(self.out, ", ")?;
3826                            self.write_expr(module, arg, func_ctx)?;
3827                        }
3828                        write!(self.out, ")")?
3829                    }
3830                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
3831                    // as non-32bit types are DXC only.
3832                    Function::MissingIntOverload(fun_name) => {
3833                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3834                        if let Some(Scalar::I32) = scalar_kind {
3835                            write!(self.out, "asint({fun_name}(asuint(")?;
3836                            self.write_expr(module, arg, func_ctx)?;
3837                            write!(self.out, ")))")?;
3838                        } else {
3839                            write!(self.out, "{fun_name}(")?;
3840                            self.write_expr(module, arg, func_ctx)?;
3841                            write!(self.out, ")")?;
3842                        }
3843                    }
3844                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
3845                    // as non-32bit types are DXC only.
3846                    Function::MissingIntReturnType(fun_name) => {
3847                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3848                        if let Some(Scalar::I32) = scalar_kind {
3849                            write!(self.out, "asint({fun_name}(")?;
3850                            self.write_expr(module, arg, func_ctx)?;
3851                            write!(self.out, "))")?;
3852                        } else {
3853                            write!(self.out, "{fun_name}(")?;
3854                            self.write_expr(module, arg, func_ctx)?;
3855                            write!(self.out, ")")?;
3856                        }
3857                    }
3858                    Function::CountTrailingZeros => {
3859                        match *func_ctx.resolve_type(arg, &module.types) {
3860                            TypeInner::Vector { size, scalar } => {
3861                                let s = match size {
3862                                    crate::VectorSize::Bi => ".xx",
3863                                    crate::VectorSize::Tri => ".xxx",
3864                                    crate::VectorSize::Quad => ".xxxx",
3865                                };
3866
3867                                let scalar_width_bits = scalar.width * 8;
3868
3869                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3870                                    write!(
3871                                        self.out,
3872                                        "min(({scalar_width_bits}u){s}, firstbitlow("
3873                                    )?;
3874                                    self.write_expr(module, arg, func_ctx)?;
3875                                    write!(self.out, "))")?;
3876                                } else {
3877                                    // This is only needed for the FXC path, on 32bit signed integers.
3878                                    write!(
3879                                        self.out,
3880                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
3881                                    )?;
3882                                    self.write_expr(module, arg, func_ctx)?;
3883                                    write!(self.out, ")))")?;
3884                                }
3885                            }
3886                            TypeInner::Scalar(scalar) => {
3887                                let scalar_width_bits = scalar.width * 8;
3888
3889                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3890                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
3891                                    self.write_expr(module, arg, func_ctx)?;
3892                                    write!(self.out, "))")?;
3893                                } else {
3894                                    // This is only needed for the FXC path, on 32bit signed integers.
3895                                    write!(
3896                                        self.out,
3897                                        "asint(min({scalar_width_bits}u, firstbitlow("
3898                                    )?;
3899                                    self.write_expr(module, arg, func_ctx)?;
3900                                    write!(self.out, ")))")?;
3901                                }
3902                            }
3903                            _ => unreachable!(),
3904                        }
3905
3906                        return Ok(());
3907                    }
3908                    Function::CountLeadingZeros => {
3909                        match *func_ctx.resolve_type(arg, &module.types) {
3910                            TypeInner::Vector { size, scalar } => {
3911                                let s = match size {
3912                                    crate::VectorSize::Bi => ".xx",
3913                                    crate::VectorSize::Tri => ".xxx",
3914                                    crate::VectorSize::Quad => ".xxxx",
3915                                };
3916
3917                                // scalar width - 1
3918                                let constant = scalar.width * 8 - 1;
3919
3920                                if scalar.kind == ScalarKind::Uint {
3921                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
3922                                    self.write_expr(module, arg, func_ctx)?;
3923                                    write!(self.out, "))")?;
3924                                } else {
3925                                    let conversion_func = match scalar.width {
3926                                        4 => "asint",
3927                                        _ => "",
3928                                    };
3929                                    write!(self.out, "(")?;
3930                                    self.write_expr(module, arg, func_ctx)?;
3931                                    write!(
3932                                        self.out,
3933                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
3934                                    )?;
3935                                    self.write_expr(module, arg, func_ctx)?;
3936                                    write!(self.out, ")))")?;
3937                                }
3938                            }
3939                            TypeInner::Scalar(scalar) => {
3940                                // scalar width - 1
3941                                let constant = scalar.width * 8 - 1;
3942
3943                                if let ScalarKind::Uint = scalar.kind {
3944                                    write!(self.out, "({constant}u - firstbithigh(")?;
3945                                    self.write_expr(module, arg, func_ctx)?;
3946                                    write!(self.out, "))")?;
3947                                } else {
3948                                    let conversion_func = match scalar.width {
3949                                        4 => "asint",
3950                                        _ => "",
3951                                    };
3952                                    write!(self.out, "(")?;
3953                                    self.write_expr(module, arg, func_ctx)?;
3954                                    write!(
3955                                        self.out,
3956                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
3957                                    )?;
3958                                    self.write_expr(module, arg, func_ctx)?;
3959                                    write!(self.out, ")))")?;
3960                                }
3961                            }
3962                            _ => unreachable!(),
3963                        }
3964
3965                        return Ok(());
3966                    }
3967                }
3968            }
3969            Expression::Swizzle {
3970                size,
3971                vector,
3972                pattern,
3973            } => {
3974                self.write_expr(module, vector, func_ctx)?;
3975                write!(self.out, ".")?;
3976                for &sc in pattern[..size as usize].iter() {
3977                    self.out.write_char(back::COMPONENTS[sc as usize])?;
3978                }
3979            }
3980            Expression::ArrayLength(expr) => {
3981                let var_handle = match func_ctx.expressions[expr] {
3982                    Expression::AccessIndex { base, index: _ } => {
3983                        match func_ctx.expressions[base] {
3984                            Expression::GlobalVariable(handle) => handle,
3985                            _ => unreachable!(),
3986                        }
3987                    }
3988                    Expression::GlobalVariable(handle) => handle,
3989                    _ => unreachable!(),
3990                };
3991
3992                let var = &module.global_variables[var_handle];
3993                let (offset, stride) = match module.types[var.ty].inner {
3994                    TypeInner::Array { stride, .. } => (0, stride),
3995                    TypeInner::Struct { ref members, .. } => {
3996                        let last = members.last().unwrap();
3997                        let stride = match module.types[last.ty].inner {
3998                            TypeInner::Array { stride, .. } => stride,
3999                            _ => unreachable!(),
4000                        };
4001                        (last.offset, stride)
4002                    }
4003                    _ => unreachable!(),
4004                };
4005
4006                let storage_access = match var.space {
4007                    crate::AddressSpace::Storage { access } => access,
4008                    _ => crate::StorageAccess::default(),
4009                };
4010                let wrapped_array_length = WrappedArrayLength {
4011                    writable: storage_access.contains(crate::StorageAccess::STORE),
4012                };
4013
4014                write!(self.out, "((")?;
4015                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4016                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4017                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4018            }
4019            Expression::Derivative { axis, ctrl, expr } => {
4020                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4021                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4022                    let tail = match ctrl {
4023                        Ctrl::Coarse => "coarse",
4024                        Ctrl::Fine => "fine",
4025                        Ctrl::None => unreachable!(),
4026                    };
4027                    write!(self.out, "abs(ddx_{tail}(")?;
4028                    self.write_expr(module, expr, func_ctx)?;
4029                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4030                    self.write_expr(module, expr, func_ctx)?;
4031                    write!(self.out, "))")?
4032                } else {
4033                    let fun_str = match (axis, ctrl) {
4034                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4035                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4036                        (Axis::X, Ctrl::None) => "ddx",
4037                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4038                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4039                        (Axis::Y, Ctrl::None) => "ddy",
4040                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4041                        (Axis::Width, Ctrl::None) => "fwidth",
4042                    };
4043                    write!(self.out, "{fun_str}(")?;
4044                    self.write_expr(module, expr, func_ctx)?;
4045                    write!(self.out, ")")?
4046                }
4047            }
4048            Expression::Relational { fun, argument } => {
4049                use crate::RelationalFunction as Rf;
4050
4051                let fun_str = match fun {
4052                    Rf::All => "all",
4053                    Rf::Any => "any",
4054                    Rf::IsNan => "isnan",
4055                    Rf::IsInf => "isinf",
4056                };
4057                write!(self.out, "{fun_str}(")?;
4058                self.write_expr(module, argument, func_ctx)?;
4059                write!(self.out, ")")?
4060            }
4061            Expression::Select {
4062                condition,
4063                accept,
4064                reject,
4065            } => {
4066                write!(self.out, "(")?;
4067                self.write_expr(module, condition, func_ctx)?;
4068                write!(self.out, " ? ")?;
4069                self.write_expr(module, accept, func_ctx)?;
4070                write!(self.out, " : ")?;
4071                self.write_expr(module, reject, func_ctx)?;
4072                write!(self.out, ")")?
4073            }
4074            Expression::RayQueryGetIntersection { query, committed } => {
4075                if committed {
4076                    write!(self.out, "GetCommittedIntersection(")?;
4077                    self.write_expr(module, query, func_ctx)?;
4078                    write!(self.out, ")")?;
4079                } else {
4080                    write!(self.out, "GetCandidateIntersection(")?;
4081                    self.write_expr(module, query, func_ctx)?;
4082                    write!(self.out, ")")?;
4083                }
4084            }
4085            // Not supported yet
4086            Expression::RayQueryVertexPositions { .. } => unreachable!(),
4087            // Nothing to do here, since call expression already cached
4088            Expression::CallResult(_)
4089            | Expression::AtomicResult { .. }
4090            | Expression::WorkGroupUniformLoadResult { .. }
4091            | Expression::RayQueryProceedResult
4092            | Expression::SubgroupBallotResult
4093            | Expression::SubgroupOperationResult { .. } => {}
4094        }
4095
4096        if !closing_bracket.is_empty() {
4097            write!(self.out, "{closing_bracket}")?;
4098        }
4099        Ok(())
4100    }
4101
4102    #[allow(clippy::too_many_arguments)]
4103    fn write_image_load(
4104        &mut self,
4105        module: &&Module,
4106        expr: Handle<crate::Expression>,
4107        func_ctx: &back::FunctionCtx,
4108        image: Handle<crate::Expression>,
4109        coordinate: Handle<crate::Expression>,
4110        array_index: Option<Handle<crate::Expression>>,
4111        sample: Option<Handle<crate::Expression>>,
4112        level: Option<Handle<crate::Expression>>,
4113    ) -> Result<(), Error> {
4114        let mut wrapping_type = None;
4115        match *func_ctx.resolve_type(image, &module.types) {
4116            TypeInner::Image {
4117                class: crate::ImageClass::Storage { format, .. },
4118                ..
4119            } => {
4120                if format.single_component() {
4121                    wrapping_type = Some(Scalar::from(format));
4122                }
4123            }
4124            _ => {}
4125        }
4126        if let Some(scalar) = wrapping_type {
4127            write!(
4128                self.out,
4129                "{}{}(",
4130                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4131                scalar.to_hlsl_str()?
4132            )?;
4133        }
4134        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4135        self.write_expr(module, image, func_ctx)?;
4136        write!(self.out, ".Load(")?;
4137
4138        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4139
4140        if let Some(sample) = sample {
4141            write!(self.out, ", ")?;
4142            self.write_expr(module, sample, func_ctx)?;
4143        }
4144
4145        // close bracket for Load function
4146        write!(self.out, ")")?;
4147
4148        if wrapping_type.is_some() {
4149            write!(self.out, ")")?;
4150        }
4151
4152        // return x component if return type is scalar
4153        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4154            write!(self.out, ".x")?;
4155        }
4156        Ok(())
4157    }
4158
4159    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4160    /// can be generated later.
4161    fn sampler_binding_array_info_from_expression(
4162        &mut self,
4163        module: &Module,
4164        func_ctx: &back::FunctionCtx<'_>,
4165        base: Handle<crate::Expression>,
4166        resolved: &TypeInner,
4167    ) -> Option<BindingArraySamplerInfo> {
4168        if let TypeInner::BindingArray {
4169            base: base_ty_handle,
4170            ..
4171        } = *resolved
4172        {
4173            let base_ty = &module.types[base_ty_handle].inner;
4174            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4175                let base = &func_ctx.expressions[base];
4176
4177                if let crate::Expression::GlobalVariable(handle) = *base {
4178                    let variable = &module.global_variables[handle];
4179
4180                    let sampler_heap_name = match comparison {
4181                        true => COMPARISON_SAMPLER_HEAP_VAR,
4182                        false => SAMPLER_HEAP_VAR,
4183                    };
4184
4185                    return Some(BindingArraySamplerInfo {
4186                        sampler_heap_name,
4187                        sampler_index_buffer_name: self
4188                            .wrapped
4189                            .sampler_index_buffers
4190                            .get(&super::SamplerIndexBufferKey {
4191                                group: variable.binding.unwrap().group,
4192                            })
4193                            .unwrap()
4194                            .clone(),
4195                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4196                            .clone(),
4197                    });
4198                }
4199            }
4200        }
4201
4202        None
4203    }
4204
4205    fn write_named_expr(
4206        &mut self,
4207        module: &Module,
4208        handle: Handle<crate::Expression>,
4209        name: String,
4210        // The expression which is being named.
4211        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4212        named: Handle<crate::Expression>,
4213        ctx: &back::FunctionCtx,
4214    ) -> BackendResult {
4215        match ctx.info[named].ty {
4216            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4217                TypeInner::Struct { .. } => {
4218                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4219                    write!(self.out, "{ty_name}")?;
4220                }
4221                _ => {
4222                    self.write_type(module, ty_handle)?;
4223                }
4224            },
4225            proc::TypeResolution::Value(ref inner) => {
4226                self.write_value_type(module, inner)?;
4227            }
4228        }
4229
4230        let resolved = ctx.resolve_type(named, &module.types);
4231
4232        write!(self.out, " {name}")?;
4233        // If rhs is a array type, we should write array size
4234        if let TypeInner::Array { base, size, .. } = *resolved {
4235            self.write_array_size(module, base, size)?;
4236        }
4237        write!(self.out, " = ")?;
4238        self.write_expr(module, handle, ctx)?;
4239        writeln!(self.out, ";")?;
4240        self.named_expressions.insert(named, name);
4241
4242        Ok(())
4243    }
4244
4245    /// Helper function that write default zero initialization
4246    pub(super) fn write_default_init(
4247        &mut self,
4248        module: &Module,
4249        ty: Handle<crate::Type>,
4250    ) -> BackendResult {
4251        write!(self.out, "(")?;
4252        self.write_type(module, ty)?;
4253        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4254            self.write_array_size(module, base, size)?;
4255        }
4256        write!(self.out, ")0")?;
4257        Ok(())
4258    }
4259
4260    fn write_control_barrier(
4261        &mut self,
4262        barrier: crate::Barrier,
4263        level: back::Level,
4264    ) -> BackendResult {
4265        if barrier.contains(crate::Barrier::STORAGE) {
4266            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4267        }
4268        if barrier.contains(crate::Barrier::WORK_GROUP) {
4269            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4270        }
4271        if barrier.contains(crate::Barrier::SUB_GROUP) {
4272            // Does not exist in DirectX
4273        }
4274        if barrier.contains(crate::Barrier::TEXTURE) {
4275            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4276        }
4277        Ok(())
4278    }
4279
4280    fn write_memory_barrier(
4281        &mut self,
4282        barrier: crate::Barrier,
4283        level: back::Level,
4284    ) -> BackendResult {
4285        if barrier.contains(crate::Barrier::STORAGE) {
4286            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4287        }
4288        if barrier.contains(crate::Barrier::WORK_GROUP) {
4289            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4290        }
4291        if barrier.contains(crate::Barrier::SUB_GROUP) {
4292            // Does not exist in DirectX
4293        }
4294        if barrier.contains(crate::Barrier::TEXTURE) {
4295            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4296        }
4297        Ok(())
4298    }
4299
4300    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4301    fn emit_hlsl_atomic_tail(
4302        &mut self,
4303        module: &Module,
4304        func_ctx: &back::FunctionCtx<'_>,
4305        fun: &crate::AtomicFunction,
4306        compare_expr: Option<Handle<crate::Expression>>,
4307        value: Handle<crate::Expression>,
4308        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4309    ) -> BackendResult {
4310        if let Some(cmp) = compare_expr {
4311            write!(self.out, ", ")?;
4312            self.write_expr(module, cmp, func_ctx)?;
4313        }
4314        write!(self.out, ", ")?;
4315        if let crate::AtomicFunction::Subtract = *fun {
4316            // we just wrote `InterlockedAdd`, so negate the argument
4317            write!(self.out, "-")?;
4318        }
4319        self.write_expr(module, value, func_ctx)?;
4320        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4321            write!(self.out, ", ")?;
4322            if compare_expr.is_some() {
4323                write!(self.out, "{res_name}.old_value")?;
4324            } else {
4325                write!(self.out, "{res_name}")?;
4326            }
4327        }
4328        writeln!(self.out, ");")?;
4329        Ok(())
4330    }
4331}
4332
/// Shape and scalar width of a matrix type, as extracted from a
/// [`TypeInner::Matrix`] by the helper functions in this file.
pub(super) struct MatrixType {
    /// Number of columns of the matrix.
    pub(super) columns: crate::VectorSize,
    /// Number of rows of the matrix.
    pub(super) rows: crate::VectorSize,
    /// Width of the matrix's scalar component (copied from `Scalar::width`).
    pub(super) width: crate::Bytes,
}
4338
4339pub(super) fn get_inner_matrix_data(
4340    module: &Module,
4341    handle: Handle<crate::Type>,
4342) -> Option<MatrixType> {
4343    match module.types[handle].inner {
4344        TypeInner::Matrix {
4345            columns,
4346            rows,
4347            scalar,
4348        } => Some(MatrixType {
4349            columns,
4350            rows,
4351            width: scalar.width,
4352        }),
4353        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4354        _ => None,
4355    }
4356}
4357
4358/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4359/// returns a tuple of the matrix, the column (vector) index (if present), and
4360/// the row (scalar) index (if present).
4361fn find_matrix_in_access_chain(
4362    module: &Module,
4363    base: Handle<crate::Expression>,
4364    func_ctx: &back::FunctionCtx<'_>,
4365) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4366    let mut current_base = base;
4367    let mut vector = None;
4368    let mut scalar = None;
4369    loop {
4370        let resolved_tr = func_ctx
4371            .resolve_type(current_base, &module.types)
4372            .pointer_base_type();
4373        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4374
4375        match *resolved {
4376            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4377            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4378            _ => return None,
4379        }
4380
4381        let index;
4382        (current_base, index) = match func_ctx.expressions[current_base] {
4383            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4384            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4385            _ => return None,
4386        };
4387
4388        match *resolved {
4389            TypeInner::Scalar(_) => scalar = Some(index),
4390            TypeInner::Vector { .. } => vector = Some(index),
4391            _ => unreachable!(),
4392        }
4393    }
4394}
4395
4396/// Returns the matrix data if the access chain starting at `base`:
4397/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4398/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4399/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4400pub(super) fn get_inner_matrix_of_struct_array_member(
4401    module: &Module,
4402    base: Handle<crate::Expression>,
4403    func_ctx: &back::FunctionCtx<'_>,
4404    direct: bool,
4405) -> Option<MatrixType> {
4406    let mut mat_data = None;
4407    let mut array_base = None;
4408
4409    let mut current_base = base;
4410    loop {
4411        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4412        if let TypeInner::Pointer { base, .. } = *resolved {
4413            resolved = &module.types[base].inner;
4414        };
4415
4416        match *resolved {
4417            TypeInner::Matrix {
4418                columns,
4419                rows,
4420                scalar,
4421            } => {
4422                mat_data = Some(MatrixType {
4423                    columns,
4424                    rows,
4425                    width: scalar.width,
4426                })
4427            }
4428            TypeInner::Array { base, .. } => {
4429                array_base = Some(base);
4430            }
4431            TypeInner::Struct { .. } => {
4432                if let Some(array_base) = array_base {
4433                    if direct {
4434                        return mat_data;
4435                    } else {
4436                        return get_inner_matrix_data(module, array_base);
4437                    }
4438                }
4439
4440                break;
4441            }
4442            _ => break,
4443        }
4444
4445        current_base = match func_ctx.expressions[current_base] {
4446            crate::Expression::Access { base, .. } => base,
4447            crate::Expression::AccessIndex { base, .. } => base,
4448            _ => break,
4449        };
4450    }
4451    None
4452}
4453
4454/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
4455/// immediate expression, rather than traversing an access chain.
4456fn get_global_uniform_matrix(
4457    module: &Module,
4458    base: Handle<crate::Expression>,
4459    func_ctx: &back::FunctionCtx<'_>,
4460) -> Option<MatrixType> {
4461    let base_tr = func_ctx
4462        .resolve_type(base, &module.types)
4463        .pointer_base_type();
4464    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4465    match (&func_ctx.expressions[base], base_ty) {
4466        (
4467            &crate::Expression::GlobalVariable(handle),
4468            Some(&TypeInner::Matrix {
4469                columns,
4470                rows,
4471                scalar,
4472            }),
4473        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4474            Some(MatrixType {
4475                columns,
4476                rows,
4477                width: scalar.width,
4478            })
4479        }
4480        _ => None,
4481    }
4482}
4483
4484/// Returns the matrix data if the access chain starting at `base`:
4485/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
4486/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4487/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
4488fn get_inner_matrix_of_global_uniform(
4489    module: &Module,
4490    base: Handle<crate::Expression>,
4491    func_ctx: &back::FunctionCtx<'_>,
4492) -> Option<MatrixType> {
4493    let mut mat_data = None;
4494    let mut array_base = None;
4495
4496    let mut current_base = base;
4497    loop {
4498        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4499        if let TypeInner::Pointer { base, .. } = *resolved {
4500            resolved = &module.types[base].inner;
4501        };
4502
4503        match *resolved {
4504            TypeInner::Matrix {
4505                columns,
4506                rows,
4507                scalar,
4508            } => {
4509                mat_data = Some(MatrixType {
4510                    columns,
4511                    rows,
4512                    width: scalar.width,
4513                })
4514            }
4515            TypeInner::Array { base, .. } => {
4516                array_base = Some(base);
4517            }
4518            _ => break,
4519        }
4520
4521        current_base = match func_ctx.expressions[current_base] {
4522            crate::Expression::Access { base, .. } => base,
4523            crate::Expression::AccessIndex { base, .. } => base,
4524            crate::Expression::GlobalVariable(handle)
4525                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4526            {
4527                return mat_data.or_else(|| {
4528                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4529                })
4530            }
4531            _ => break,
4532        };
4533    }
4534    None
4535}