naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9    help,
10    help::{
11        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12        WrappedZeroValue,
13    },
14    storage::StoreValue,
15    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18    back::{self, get_entry_points, Baked},
19    common,
20    proc::{self, index, ExternalTextureNameKey, NameKey},
21    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Semantic prefix for user-defined varyings (`LOC0`, `LOC1`, ...). Fragment
/// shader outputs use `SV_Target{N}` instead (see `write_semantic`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Type name of the struct holding the special constants; emitted by `write`
/// when `Options::special_constants_binding` is set.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the `ConstantBuffer` holding the special constants.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special constants struct holding the first vertex.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special constants struct holding the first instance.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` field of the special constants struct. NOTE(review): its consumer is
/// not in this part of the file — confirm what it carries.
const SPECIAL_OTHER: &str = "other";

// Identifiers of helper functions and variables that the HLSL backend emits
// into the generated code. They are `pub(crate)` so that other modules of the
// backend can refer to them; each name mirrors the operation it implements
// (e.g. `MODF_FUNCTION` -> `naga_modf`). See the emitting code for the exact
// semantics of each helper.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
49
/// An index that is either given by an expression or known statically.
///
/// NOTE(review): its users are outside this part of the file; presumably it is
/// used when writing access chains — confirm.
enum Index {
    /// The index is the value of an expression of the function being written.
    Expression(Handle<crate::Expression>),
    /// The index is a constant known at code-generation time.
    Static(u32),
}
54
/// One field of a flattened entry point interface struct.
///
/// Instances are built by `write_ep_input_struct` / `write_ep_output_struct`
/// and written out by `write_interface_struct`.
struct EpStructMember {
    /// Field name in the generated HLSL struct (obtained from the namer).
    name: String,
    /// Type of the field.
    ty: Handle<crate::Type>,
    /// Built-in or location binding of the field.
    ///
    /// Technically, this should always be `Some`
    /// (we `debug_assert!` this in `write_interface_struct`).
    binding: Option<crate::Binding>,
    /// Position of the member in the original, unflattened order: the running
    /// count of flattened arguments for inputs, or the result struct member
    /// index for outputs. Used to restore the original order of inputs after
    /// sorting by binding.
    index: u32,
}
63
64/// Structure contains information required for generating
65/// wrapped structure of all entry points arguments
66struct EntryPointBinding {
67    /// Name of the fake EP argument that contains the struct
68    /// with all the flattened input data.
69    arg_name: String,
70    /// Generated structure name
71    ty_name: String,
72    /// Members of generated structure
73    members: Vec<EpStructMember>,
74}
75
76pub(super) struct EntryPointInterface {
77    /// If `Some`, the input of an entry point is gathered in a special
78    /// struct with members sorted by binding.
79    /// The `EntryPointBinding::members` array is sorted by index,
80    /// so that we can walk it in `write_ep_arguments_initialization`.
81    input: Option<EntryPointBinding>,
82    /// If `Some`, the output of an entry point is flattened.
83    /// The `EntryPointBinding::members` array is sorted by binding,
84    /// So that we can walk it in `Statement::Return` handler.
85    output: Option<EntryPointBinding>,
86}
87
88#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
89enum InterfaceKey {
90    Location(u32),
91    BuiltIn(crate::BuiltIn),
92    Other,
93}
94
95impl InterfaceKey {
96    const fn new(binding: Option<&crate::Binding>) -> Self {
97        match binding {
98            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
99            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
100            None => Self::Other,
101        }
102    }
103}
104
/// Direction of an entry point interface struct.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    /// Entry point arguments.
    Input,
    /// Entry point results.
    Output,
}
110
111const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
112    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
113        return false;
114    };
115    matches!(
116        builtin,
117        crate::BuiltIn::SubgroupSize
118            | crate::BuiltIn::SubgroupInvocationId
119            | crate::BuiltIn::NumSubgroups
120            | crate::BuiltIn::SubgroupId
121    )
122}
123
/// Information for how to generate a `binding_array<sampler>` access.
///
/// NOTE(review): not used in this part of the file; the access is presumably
/// lowered as `heap[index_buffer[base_index + i]]` — confirm at the use site.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
133
134impl<'a, W: fmt::Write> super::Writer<'a, W> {
135    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
136        Self {
137            out,
138            names: crate::FastHashMap::default(),
139            namer: proc::Namer::default(),
140            options,
141            pipeline_options,
142            entry_point_io: crate::FastHashMap::default(),
143            named_expressions: crate::NamedExpressions::default(),
144            wrapped: super::Wrapped::default(),
145            written_committed_intersection: false,
146            written_candidate_intersection: false,
147            continue_ctx: back::continue_forward::ContinueCtx::default(),
148            temp_access_chain: Vec::new(),
149            need_bake_expressions: Default::default(),
150        }
151    }
152
153    fn reset(&mut self, module: &Module) {
154        self.names.clear();
155        self.namer.reset(
156            module,
157            &super::keywords::RESERVED_SET,
158            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
159            super::keywords::RESERVED_PREFIXES,
160            &mut self.names,
161        );
162        self.entry_point_io.clear();
163        self.named_expressions.clear();
164        self.wrapped.clear();
165        self.written_committed_intersection = false;
166        self.written_candidate_intersection = false;
167        self.continue_ctx.clear();
168        self.need_bake_expressions.clear();
169    }
170
171    /// Generates statements to be inserted immediately before and at the very
172    /// start of the body of each loop, to defeat infinite loop reasoning.
173    /// The 0th item of the returned tuple should be inserted immediately prior
174    /// to the loop and the 1st item should be inserted at the very start of
175    /// the loop body.
176    ///
177    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
178    fn gen_force_bounded_loop_statements(
179        &mut self,
180        level: back::Level,
181    ) -> Option<(String, String)> {
182        if !self.options.force_loop_bounding {
183            return None;
184        }
185
186        let loop_bound_name = self.namer.call("loop_bound");
187        let max = u32::MAX;
188        // Count down from u32::MAX rather than up from 0 to avoid hang on
189        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
190        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
191        let level = level.next();
192        let break_and_inc = format!(
193            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
194{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
195        );
196
197        Some((decl, break_and_inc))
198    }
199
200    /// Helper method used to find which expressions of a given function require baking
201    ///
202    /// # Notes
203    /// Clears `need_bake_expressions` set before adding to it
204    fn update_expressions_to_bake(
205        &mut self,
206        module: &Module,
207        func: &crate::Function,
208        info: &valid::FunctionInfo,
209    ) {
210        use crate::Expression;
211        self.need_bake_expressions.clear();
212        for (exp_handle, expr) in func.expressions.iter() {
213            let expr_info = &info[exp_handle];
214            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
215            if min_ref_count <= expr_info.ref_count {
216                self.need_bake_expressions.insert(exp_handle);
217            }
218
219            if let Expression::Math { fun, arg, arg1, .. } = *expr {
220                match fun {
221                    crate::MathFunction::Asinh
222                    | crate::MathFunction::Acosh
223                    | crate::MathFunction::Atanh
224                    | crate::MathFunction::Unpack2x16float
225                    | crate::MathFunction::Unpack2x16snorm
226                    | crate::MathFunction::Unpack2x16unorm
227                    | crate::MathFunction::Unpack4x8snorm
228                    | crate::MathFunction::Unpack4x8unorm
229                    | crate::MathFunction::Unpack4xI8
230                    | crate::MathFunction::Unpack4xU8
231                    | crate::MathFunction::Pack2x16float
232                    | crate::MathFunction::Pack2x16snorm
233                    | crate::MathFunction::Pack2x16unorm
234                    | crate::MathFunction::Pack4x8snorm
235                    | crate::MathFunction::Pack4x8unorm
236                    | crate::MathFunction::Pack4xI8
237                    | crate::MathFunction::Pack4xU8
238                    | crate::MathFunction::Pack4xI8Clamp
239                    | crate::MathFunction::Pack4xU8Clamp => {
240                        self.need_bake_expressions.insert(arg);
241                    }
242                    crate::MathFunction::CountLeadingZeros => {
243                        let inner = info[exp_handle].ty.inner_with(&module.types);
244                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
245                            self.need_bake_expressions.insert(arg);
246                        }
247                    }
248                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
249                        self.need_bake_expressions.insert(arg);
250                        self.need_bake_expressions.insert(arg1.unwrap());
251                    }
252                    _ => {}
253                }
254            }
255
256            if let Expression::Derivative { axis, ctrl, expr } = *expr {
257                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
258                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
259                    self.need_bake_expressions.insert(expr);
260                }
261            }
262
263            if let Expression::GlobalVariable(_) = *expr {
264                let inner = info[exp_handle].ty.inner_with(&module.types);
265
266                if let TypeInner::Sampler { .. } = *inner {
267                    self.need_bake_expressions.insert(exp_handle);
268                }
269            }
270        }
271        for statement in func.body.iter() {
272            match *statement {
273                crate::Statement::SubgroupCollectiveOperation {
274                    op: _,
275                    collective_op: crate::CollectiveOperation::InclusiveScan,
276                    argument,
277                    result: _,
278                } => {
279                    self.need_bake_expressions.insert(argument);
280                }
281                crate::Statement::Atomic {
282                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
283                    ..
284                } => {
285                    self.need_bake_expressions.insert(cmp);
286                }
287                _ => {}
288            }
289        }
290    }
291
    /// Writes the whole `module` as HLSL into `self.out`.
    ///
    /// Output is produced in this order:
    /// 1. the special-constants cbuffer (only if `special_constants_binding` is set);
    /// 2. one cbuffer per `dynamic_storage_buffer_offsets_targets` entry;
    /// 3. whatever `write_all_mat_cx2_typedefs_and_functions` emits;
    /// 4. all structs whose last member is not dynamically sized;
    /// 5. special functions and wrapped helpers needed by global expressions;
    /// 6. named constants, then global variables;
    /// 7. the flattened interface structs of the selected entry points;
    /// 8. regular functions, then the selected entry points.
    ///
    /// Unless `fake_missing_bindings` is set, a function or entry point that uses
    /// a global whose binding cannot be resolved (neither as a regular resource
    /// nor as an external texture) is skipped. A skipped entry point is reported
    /// as an `Err` in `ReflectionInfo::entry_point_names`.
    ///
    /// # Errors
    /// Returns `Error::EntryPointNotFound` when the entry point selected by
    /// `pipeline_options.entry_point` does not exist, and propagates any
    /// formatting or generation error.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // `space0` is the default, so it is only spelled out for other spaces.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One struct + cbuffer per bind group, with `bt.size` `uint` fields
        // named `_0`, `_1`, ...
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                // If this struct is the result type of an entry point, pass that
                // entry point's stage along so the struct is written as an output.
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        // Range of entry points to emit: a single one if the pipeline options
        // select it, otherwise presumably all of them (see `get_entry_points`).
        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(
                module,
                &ep.function,
                ep.stage,
                &ep_name,
                fragment_entry_point,
            )?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            if !self.options.fake_missing_bindings {
                // A global counts as inaccessible when the function uses it and
                // its binding resolves neither as a regular resource nor as an
                // external texture.
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::info!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info)?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // Same accessibility check as for regular functions, but the
            // failure is recorded in the reflection info instead of only logged.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            if ep.stage.compute_like() {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    self.out,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info)?;

            // Blank line between entry points. NOTE(review): this compares with
            // the module's entry point count, not with `ep_range`, so a single
            // selected entry point that is not the last one gets a trailing
            // blank line — harmless, but confirm it is intended.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
534
535    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
536        match *binding {
537            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
538                write!(self.out, "precise ")?;
539            }
540            crate::Binding::Location {
541                interpolation,
542                sampling,
543                ..
544            } => {
545                if let Some(interpolation) = interpolation {
546                    if let Some(string) = interpolation.to_hlsl_str() {
547                        write!(self.out, "{string} ")?
548                    }
549                }
550
551                if let Some(sampling) = sampling {
552                    if let Some(string) = sampling.to_hlsl_str() {
553                        write!(self.out, "{string} ")?
554                    }
555                }
556            }
557            crate::Binding::BuiltIn(_) => {}
558        }
559
560        Ok(())
561    }
562
563    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
564    // if they are struct, so that the `stage` argument here could be omitted.
565    fn write_semantic(
566        &mut self,
567        binding: &Option<crate::Binding>,
568        stage: Option<(ShaderStage, Io)>,
569    ) -> BackendResult {
570        match *binding {
571            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
572                if builtin == crate::BuiltIn::ViewIndex
573                    && self.options.shader_model < ShaderModel::V6_1
574                {
575                    return Err(Error::ShaderModelTooLow(
576                        "used @builtin(view_index) or SV_ViewID".to_string(),
577                        ShaderModel::V6_1,
578                    ));
579                }
580                let builtin_str = builtin.to_hlsl_str()?;
581                write!(self.out, " : {builtin_str}")?;
582            }
583            Some(crate::Binding::Location {
584                blend_src: Some(1), ..
585            }) => {
586                write!(self.out, " : SV_Target1")?;
587            }
588            Some(crate::Binding::Location { location, .. }) => {
589                if stage == Some((ShaderStage::Fragment, Io::Output)) {
590                    write!(self.out, " : SV_Target{location}")?;
591                } else {
592                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
593                }
594            }
595            _ => {}
596        }
597
598        Ok(())
599    }
600
601    fn write_interface_struct(
602        &mut self,
603        module: &Module,
604        shader_stage: (ShaderStage, Io),
605        struct_name: String,
606        mut members: Vec<EpStructMember>,
607    ) -> Result<EntryPointBinding, Error> {
608        // Sort the members so that first come the user-defined varyings
609        // in ascending locations, and then built-ins. This allows VS and FS
610        // interfaces to match with regards to order.
611        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
612
613        write!(self.out, "struct {struct_name}")?;
614        writeln!(self.out, " {{")?;
615        for m in members.iter() {
616            // Sanity check that each IO member is a built-in or is assigned a
617            // location. Also see note about nesting in `write_ep_input_struct`.
618            debug_assert!(m.binding.is_some());
619
620            if is_subgroup_builtin_binding(&m.binding) {
621                continue;
622            }
623            write!(self.out, "{}", back::INDENT)?;
624            if let Some(ref binding) = m.binding {
625                self.write_modifier(binding)?;
626            }
627            self.write_type(module, m.ty)?;
628            write!(self.out, " {}", &m.name)?;
629            self.write_semantic(&m.binding, Some(shader_stage))?;
630            writeln!(self.out, ";")?;
631        }
632        if members.iter().any(|arg| {
633            matches!(
634                arg.binding,
635                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
636            )
637        }) {
638            writeln!(
639                self.out,
640                "{}uint __local_invocation_index : SV_GroupIndex;",
641                back::INDENT
642            )?;
643        }
644        writeln!(self.out, "}};")?;
645        writeln!(self.out)?;
646
647        // See ordering notes on EntryPointInterface fields
648        match shader_stage.1 {
649            Io::Input => {
650                // bring back the original order
651                members.sort_by_key(|m| m.index);
652            }
653            Io::Output => {
654                // keep it sorted by binding
655            }
656        }
657
658        Ok(EntryPointBinding {
659            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
660            ty_name: struct_name,
661            members,
662        })
663    }
664
665    /// Flatten all entry point arguments into a single struct.
666    /// This is needed since we need to re-order them: first placing user locations,
667    /// then built-ins.
668    fn write_ep_input_struct(
669        &mut self,
670        module: &Module,
671        func: &crate::Function,
672        stage: ShaderStage,
673        entry_point_name: &str,
674    ) -> Result<EntryPointBinding, Error> {
675        let struct_name = format!("{stage:?}Input_{entry_point_name}");
676
677        let mut fake_members = Vec::new();
678        for arg in func.arguments.iter() {
679            // NOTE: We don't need to handle nesting structs. All members must
680            // be either built-ins or assigned a location. I.E. `binding` is
681            // `Some`. This is checked in `VaryingContext::validate`. See:
682            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
683            match module.types[arg.ty].inner {
684                TypeInner::Struct { ref members, .. } => {
685                    for member in members.iter() {
686                        let name = self.namer.call_or(&member.name, "member");
687                        let index = fake_members.len() as u32;
688                        fake_members.push(EpStructMember {
689                            name,
690                            ty: member.ty,
691                            binding: member.binding.clone(),
692                            index,
693                        });
694                    }
695                }
696                _ => {
697                    let member_name = self.namer.call_or(&arg.name, "member");
698                    let index = fake_members.len() as u32;
699                    fake_members.push(EpStructMember {
700                        name: member_name,
701                        ty: arg.ty,
702                        binding: arg.binding.clone(),
703                        index,
704                    });
705                }
706            }
707        }
708
709        self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
710    }
711
712    /// Flatten all entry point results into a single struct.
713    /// This is needed since we need to re-order them: first placing user locations,
714    /// then built-ins.
715    fn write_ep_output_struct(
716        &mut self,
717        module: &Module,
718        result: &crate::FunctionResult,
719        stage: ShaderStage,
720        entry_point_name: &str,
721        frag_ep: Option<&FragmentEntryPoint<'_>>,
722    ) -> Result<EntryPointBinding, Error> {
723        let struct_name = format!("{stage:?}Output_{entry_point_name}");
724
725        let empty = [];
726        let members = match module.types[result.ty].inner {
727            TypeInner::Struct { ref members, .. } => members,
728            ref other => {
729                log::error!("Unexpected {other:?} output type without a binding");
730                &empty[..]
731            }
732        };
733
734        // Gather list of fragment input locations. We use this below to remove user-defined
735        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
736        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
737        // writer is supplied with information about the fragment entry point.
738        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
739            let mut fs_input_locs = Vec::new();
740            for arg in frag_ep.func.arguments.iter() {
741                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
742                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
743                    Some(crate::Binding::BuiltIn(_)) | None => {}
744                };
745
746                // NOTE: We don't need to handle struct nesting. See note in
747                // `write_ep_input_struct`.
748                match frag_ep.module.types[arg.ty].inner {
749                    TypeInner::Struct { ref members, .. } => {
750                        for member in members.iter() {
751                            push_if_location(&member.binding);
752                        }
753                    }
754                    _ => push_if_location(&arg.binding),
755                }
756            }
757            fs_input_locs.sort();
758            Some(fs_input_locs)
759        } else {
760            None
761        };
762
763        let mut fake_members = Vec::new();
764        for (index, member) in members.iter().enumerate() {
765            if let Some(ref fs_input_locs) = fs_input_locs {
766                match member.binding {
767                    Some(crate::Binding::Location { location, .. }) => {
768                        if fs_input_locs.binary_search(&location).is_err() {
769                            continue;
770                        }
771                    }
772                    Some(crate::Binding::BuiltIn(_)) | None => {}
773                }
774            }
775
776            let member_name = self.namer.call_or(&member.name, "member");
777            fake_members.push(EpStructMember {
778                name: member_name,
779                ty: member.ty,
780                binding: member.binding.clone(),
781                index: index as u32,
782            });
783        }
784
785        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
786    }
787
788    /// Writes special interface structures for an entry point. The special structures have
789    /// all the fields flattened into them and sorted by binding. They are needed to emulate
790    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
791    fn write_ep_interface(
792        &mut self,
793        module: &Module,
794        func: &crate::Function,
795        stage: ShaderStage,
796        ep_name: &str,
797        frag_ep: Option<&FragmentEntryPoint<'_>>,
798    ) -> Result<EntryPointInterface, Error> {
799        Ok(EntryPointInterface {
800            input: if !func.arguments.is_empty()
801                && (stage == ShaderStage::Fragment
802                    || func
803                        .arguments
804                        .iter()
805                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
806            {
807                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
808            } else {
809                None
810            },
811            output: match func.result {
812                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
813                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
814                }
815                _ => None,
816            },
817        })
818    }
819
820    fn write_ep_argument_initialization(
821        &mut self,
822        ep: &crate::EntryPoint,
823        ep_input: &EntryPointBinding,
824        fake_member: &EpStructMember,
825    ) -> BackendResult {
826        match fake_member.binding {
827            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
828                write!(self.out, "WaveGetLaneCount()")?
829            }
830            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
831                write!(self.out, "WaveGetLaneIndex()")?
832            }
833            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
834                self.out,
835                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
836                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
837            )?,
838            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
839                write!(
840                    self.out,
841                    "{}.__local_invocation_index / WaveGetLaneCount()",
842                    ep_input.arg_name
843                )?;
844            }
845            _ => {
846                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
847            }
848        }
849        Ok(())
850    }
851
852    /// Write an entry point preface that initializes the arguments as specified in IR.
853    fn write_ep_arguments_initialization(
854        &mut self,
855        module: &Module,
856        func: &crate::Function,
857        ep_index: u16,
858    ) -> BackendResult {
859        let ep = &module.entry_points[ep_index as usize];
860        let ep_input = match self
861            .entry_point_io
862            .get_mut(&(ep_index as usize))
863            .unwrap()
864            .input
865            .take()
866        {
867            Some(ep_input) => ep_input,
868            None => return Ok(()),
869        };
870        let mut fake_iter = ep_input.members.iter();
871        for (arg_index, arg) in func.arguments.iter().enumerate() {
872            write!(self.out, "{}", back::INDENT)?;
873            self.write_type(module, arg.ty)?;
874            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
875            write!(self.out, " {arg_name}")?;
876            match module.types[arg.ty].inner {
877                TypeInner::Array { base, size, .. } => {
878                    self.write_array_size(module, base, size)?;
879                    write!(self.out, " = ")?;
880                    self.write_ep_argument_initialization(
881                        ep,
882                        &ep_input,
883                        fake_iter.next().unwrap(),
884                    )?;
885                    writeln!(self.out, ";")?;
886                }
887                TypeInner::Struct { ref members, .. } => {
888                    write!(self.out, " = {{ ")?;
889                    for index in 0..members.len() {
890                        if index != 0 {
891                            write!(self.out, ", ")?;
892                        }
893                        self.write_ep_argument_initialization(
894                            ep,
895                            &ep_input,
896                            fake_iter.next().unwrap(),
897                        )?;
898                    }
899                    writeln!(self.out, " }};")?;
900                }
901                _ => {
902                    write!(self.out, " = ")?;
903                    self.write_ep_argument_initialization(
904                        ep,
905                        &ep_input,
906                        fake_iter.next().unwrap(),
907                    )?;
908                    writeln!(self.out, ";")?;
909                }
910            }
911        }
912        assert!(fake_iter.next().is_none());
913        Ok(())
914    }
915
    /// Writes the declaration of a module-scope global variable.
    ///
    /// The emitted form depends on the global's address space, roughly:
    /// - `Private`: `static <type> <name>[<size>] = <init>;`. The global's initializer is
    ///   used, or the type's default initializer when there is none.
    /// - `WorkGroup`: `groupshared <type> <name>[<size>];`
    /// - `Uniform`: `cbuffer <name> : register(bN) { <type> <name>[<size>]; }`
    /// - `Storage`: `ByteAddressBuffer` with a `t` register, or `RWByteAddressBuffer` with a
    ///   `u` register when the access includes `STORE`.
    /// - `Handle`: the resource type itself. Storage images use a `u` register; every other
    ///   handle uses `t`.
    /// - `PushConstant`: `ConstantBuffer<T> <name>` bound to `Options::push_constants_target`.
    ///
    /// External textures are delegated to `write_global_external_texture`, and samplers to
    /// `write_global_sampler`. A global whose resource binding cannot be resolved is skipped
    /// and produces no output; the skip is logged at `info` level.
    ///
    /// # Errors
    /// Returns [`Error::Unimplemented`] for a push constant whose type is not a struct. It
    /// also propagates write and type-emission errors.
    ///
    /// # Panics
    /// Panics in these cases:
    /// - the global is in the `Function` address space (`unreachable!`);
    /// - the global is in the `TaskPayload` address space (`unimplemented!`);
    /// - the global is a push constant and no `push_constants_target` is configured.
    ///
    /// # Notes
    /// Always adds a newline
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify the global by its element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are handled entirely differently, so defer entirely to that method.
        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
        // their bindings separately.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        // A global whose binding the options cannot resolve is left out of the output
        // entirely instead of being treated as an error.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::info!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // Samplers are handled entirely differently, so defer entirely to that method.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Emit the storage-class prefix (and, for most spaces, the type) for each address
        // space. Also pick the HLSL register class used for the binding:
        // - `b`: constant buffer
        // - `t`: shader resource view
        // - `u`: unordered access view
        // - empty: the variable is not a bound resource
        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::TaskPayload => unimplemented!(),
            crate::AddressSpace::Uniform => {
                // constant buffer declarations are expected to be inlined, e.g.
                // `cbuffer foo: register(b0) { field1: type1; }`
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                // Writable storage needs a UAV (`RW` prefix, `u` register); read-only
                // storage is an SRV (`t` register).
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // all storage textures are UAV, unconditionally
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::PushConstant => {
                // The type of the push constants will be wrapped in `ConstantBuffer`
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
        };

        // If the global is a push constant write the type now because it will be a
        // generic argument to `ConstantBuffer`
        if global.space == crate::AddressSpace::PushConstant {
            self.write_global_type(module, global.ty)?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Push constants need to be assigned a binding explicitly by the consumer
        // since naga has no way to know the binding from the shader alone
        if global.space == crate::AddressSpace::PushConstant {
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .push_constants_target
                .as_ref()
                .expect("No bind target was defined for the push constants block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // This binding was already resolved successfully near the top of this function
            // (unresolvable globals returned early there), so the unwrap cannot fail.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // need to write the binding array size if the type was emitted with `write_type`
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                // A size override from the bind target takes precedence over the IR size.
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Only private globals get an initializer; without an IR initializer the type's
            // default initializer is used.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // The cbuffer body holds a single member that has the same name as the cbuffer
            // and contains the whole uniform value.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1110
1111    fn write_global_sampler(
1112        &mut self,
1113        module: &Module,
1114        handle: Handle<crate::GlobalVariable>,
1115        global: &crate::GlobalVariable,
1116    ) -> BackendResult {
1117        let binding = *global.binding.as_ref().unwrap();
1118
1119        let key = super::SamplerIndexBufferKey {
1120            group: binding.group,
1121        };
1122        self.write_wrapped_sampler_buffer(key)?;
1123
1124        // This was already validated, so we can confidently unwrap it.
1125        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1126
1127        match module.types[global.ty].inner {
1128            TypeInner::Sampler { comparison } => {
1129                // If we are generating a static access, we create a variable for the sampler.
1130                //
1131                // This prevents the DXIL from containing multiple lookups for the sampler, which
1132                // the backend compiler will then have to eliminate. AMD does seem to be able to
1133                // eliminate these, but better safe than sorry.
1134
1135                write!(self.out, "static const ")?;
1136                self.write_type(module, global.ty)?;
1137
1138                let heap_var = if comparison {
1139                    COMPARISON_SAMPLER_HEAP_VAR
1140                } else {
1141                    SAMPLER_HEAP_VAR
1142                };
1143
1144                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1145                let name = &self.names[&NameKey::GlobalVariable(handle)];
1146                writeln!(
1147                    self.out,
1148                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1149                    register = bt.register
1150                )?;
1151            }
1152            TypeInner::BindingArray { .. } => {
1153                // If we are generating a binding array, we cannot directly access the sampler as the index
1154                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1155                // that represents the "base" index into the sampler index buffer. This constant is added
1156                // to the user provided index to get the final index into the sampler index buffer.
1157
1158                let name = &self.names[&NameKey::GlobalVariable(handle)];
1159                writeln!(
1160                    self.out,
1161                    "static const uint {name} = {register};",
1162                    register = bt.register
1163                )?;
1164            }
1165            _ => unreachable!(),
1166        };
1167
1168        Ok(())
1169    }
1170
1171    /// Write the declarations for an external texture global variable.
1172    /// These are emitted as multiple global variables: Three `Texture2D`s
1173    /// (one for each plane) and a parameters cbuffer.
1174    fn write_global_external_texture(
1175        &mut self,
1176        module: &Module,
1177        handle: Handle<crate::GlobalVariable>,
1178        global: &crate::GlobalVariable,
1179    ) -> BackendResult {
1180        let res_binding = global
1181            .binding
1182            .as_ref()
1183            .expect("External texture global variables must have a resource binding");
1184        let ext_tex_bindings = match self
1185            .options
1186            .resolve_external_texture_resource_binding(res_binding)
1187        {
1188            Ok(bindings) => bindings,
1189            Err(err) => {
1190                log::info!(
1191                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1192                    handle,
1193                    global.name,
1194                    err,
1195                );
1196                return Ok(());
1197            }
1198        };
1199
1200        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1201            write!(
1202                self.out,
1203                "Texture2D<float4> {}: register(t{}",
1204                name, bt.register
1205            )?;
1206            if bt.space != 0 {
1207                write!(self.out, ", space{}", bt.space)?;
1208            }
1209            writeln!(self.out, ");")?;
1210            Ok(())
1211        };
1212        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1213            let plane_name = &self.names
1214                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1215            write_plane(bt, plane_name)?;
1216        }
1217
1218        let params_name = &self.names
1219            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1220        let params_ty_name =
1221            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1222        write!(
1223            self.out,
1224            "cbuffer {}: register(b{}",
1225            params_name, ext_tex_bindings.params.register
1226        )?;
1227        if ext_tex_bindings.params.space != 0 {
1228            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1229        }
1230        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1231
1232        Ok(())
1233    }
1234
1235    /// Helper method used to write global constants
1236    ///
1237    /// # Notes
1238    /// Ends in a newline
1239    fn write_global_constant(
1240        &mut self,
1241        module: &Module,
1242        handle: Handle<crate::Constant>,
1243    ) -> BackendResult {
1244        write!(self.out, "static const ")?;
1245        let constant = &module.constants[handle];
1246        self.write_type(module, constant.ty)?;
1247        let name = &self.names[&NameKey::Constant(handle)];
1248        write!(self.out, " {name}")?;
1249        // Write size for array type
1250        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1251            self.write_array_size(module, base, size)?;
1252        }
1253        write!(self.out, " = ")?;
1254        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1255        writeln!(self.out, ";")?;
1256        Ok(())
1257    }
1258
1259    pub(super) fn write_array_size(
1260        &mut self,
1261        module: &Module,
1262        base: Handle<crate::Type>,
1263        size: crate::ArraySize,
1264    ) -> BackendResult {
1265        write!(self.out, "[")?;
1266
1267        match size.resolve(module.to_ctx())? {
1268            proc::IndexableLength::Known(size) => {
1269                write!(self.out, "{size}")?;
1270            }
1271            proc::IndexableLength::Dynamic => unreachable!(),
1272        }
1273
1274        write!(self.out, "]")?;
1275
1276        if let TypeInner::Array {
1277            base: next_base,
1278            size: next_size,
1279            ..
1280        } = module.types[base].inner
1281        {
1282            self.write_array_size(module, next_base, next_size)?;
1283        }
1284
1285        Ok(())
1286    }
1287
1288    /// Helper method used to write structs
1289    ///
1290    /// # Notes
1291    /// Ends in a newline
1292    fn write_struct(
1293        &mut self,
1294        module: &Module,
1295        handle: Handle<crate::Type>,
1296        members: &[crate::StructMember],
1297        span: u32,
1298        shader_stage: Option<(ShaderStage, Io)>,
1299    ) -> BackendResult {
1300        // Write struct name
1301        let struct_name = &self.names[&NameKey::Type(handle)];
1302        writeln!(self.out, "struct {struct_name} {{")?;
1303
1304        let mut last_offset = 0;
1305        for (index, member) in members.iter().enumerate() {
1306            if member.binding.is_none() && member.offset > last_offset {
1307                // using int as padding should work as long as the backend
1308                // doesn't support a type that's less than 4 bytes in size
1309                // (Error::UnsupportedScalar catches this)
1310                let padding = (member.offset - last_offset) / 4;
1311                for i in 0..padding {
1312                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1313                }
1314            }
1315            let ty_inner = &module.types[member.ty].inner;
1316            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1317
1318            // The indentation is only for readability
1319            write!(self.out, "{}", back::INDENT)?;
1320
1321            match module.types[member.ty].inner {
1322                TypeInner::Array { base, size, .. } => {
1323                    // HLSL arrays are written as `type name[size]`
1324
1325                    self.write_global_type(module, member.ty)?;
1326
1327                    // Write `name`
1328                    write!(
1329                        self.out,
1330                        " {}",
1331                        &self.names[&NameKey::StructMember(handle, index as u32)]
1332                    )?;
1333                    // Write [size]
1334                    self.write_array_size(module, base, size)?;
1335                }
1336                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1337                // See the module-level block comment in mod.rs for details.
1338                TypeInner::Matrix {
1339                    rows,
1340                    columns,
1341                    scalar,
1342                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1343                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1344                    let field_name_key = NameKey::StructMember(handle, index as u32);
1345
1346                    for i in 0..columns as u8 {
1347                        if i != 0 {
1348                            write!(self.out, "; ")?;
1349                        }
1350                        self.write_value_type(module, &vec_ty)?;
1351                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1352                    }
1353                }
1354                _ => {
1355                    // Write modifier before type
1356                    if let Some(ref binding) = member.binding {
1357                        self.write_modifier(binding)?;
1358                    }
1359
1360                    // Even though Naga IR matrices are column-major, we must describe
1361                    // matrices passed from the CPU as being in row-major order.
1362                    // See the module-level block comment in mod.rs for details.
1363                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1364                        write!(self.out, "row_major ")?;
1365                    }
1366
1367                    // Write the member type and name
1368                    self.write_type(module, member.ty)?;
1369                    write!(
1370                        self.out,
1371                        " {}",
1372                        &self.names[&NameKey::StructMember(handle, index as u32)]
1373                    )?;
1374                }
1375            }
1376
1377            self.write_semantic(&member.binding, shader_stage)?;
1378            writeln!(self.out, ";")?;
1379        }
1380
1381        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1382        if members.last().unwrap().binding.is_none() && span > last_offset {
1383            let padding = (span - last_offset) / 4;
1384            for i in 0..padding {
1385                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1386            }
1387        }
1388
1389        writeln!(self.out, "}};")?;
1390        Ok(())
1391    }
1392
1393    /// Helper method used to write global/structs non image/sampler types
1394    ///
1395    /// # Notes
1396    /// Adds no trailing or leading whitespace
1397    pub(super) fn write_global_type(
1398        &mut self,
1399        module: &Module,
1400        ty: Handle<crate::Type>,
1401    ) -> BackendResult {
1402        let matrix_data = get_inner_matrix_data(module, ty);
1403
1404        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1405        // See the module-level block comment in mod.rs for details.
1406        if let Some(MatrixType {
1407            columns,
1408            rows: crate::VectorSize::Bi,
1409            width: 4,
1410        }) = matrix_data
1411        {
1412            write!(self.out, "__mat{}x2", columns as u8)?;
1413        } else {
1414            // Even though Naga IR matrices are column-major, we must describe
1415            // matrices passed from the CPU as being in row-major order.
1416            // See the module-level block comment in mod.rs for details.
1417            if matrix_data.is_some() {
1418                write!(self.out, "row_major ")?;
1419            }
1420
1421            self.write_type(module, ty)?;
1422        }
1423
1424        Ok(())
1425    }
1426
1427    /// Helper method used to write non image/sampler types
1428    ///
1429    /// # Notes
1430    /// Adds no trailing or leading whitespace
1431    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1432        let inner = &module.types[ty].inner;
1433        match *inner {
1434            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1435            // hlsl array has the size separated from the base type
1436            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1437                self.write_type(module, base)?
1438            }
1439            ref other => self.write_value_type(module, other)?,
1440        }
1441
1442        Ok(())
1443    }
1444
1445    /// Helper method used to write value types
1446    ///
1447    /// # Notes
1448    /// Adds no trailing or leading whitespace
1449    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1450        match *inner {
1451            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1452                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1453            }
1454            TypeInner::Vector { size, scalar } => {
1455                write!(
1456                    self.out,
1457                    "{}{}",
1458                    scalar.to_hlsl_str()?,
1459                    common::vector_size_str(size)
1460                )?;
1461            }
1462            TypeInner::Matrix {
1463                columns,
1464                rows,
1465                scalar,
1466            } => {
1467                // The IR supports only float matrix
1468                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1469
1470                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1471                write!(
1472                    self.out,
1473                    "{}{}x{}",
1474                    scalar.to_hlsl_str()?,
1475                    common::vector_size_str(columns),
1476                    common::vector_size_str(rows),
1477                )?;
1478            }
1479            TypeInner::Image {
1480                dim,
1481                arrayed,
1482                class,
1483            } => {
1484                self.write_image_type(dim, arrayed, class)?;
1485            }
1486            TypeInner::Sampler { comparison } => {
1487                let sampler = if comparison {
1488                    "SamplerComparisonState"
1489                } else {
1490                    "SamplerState"
1491                };
1492                write!(self.out, "{sampler}")?;
1493            }
1494            // HLSL arrays are written as `type name[size]`
1495            // Current code is written arrays only as `[size]`
1496            // Base `type` and `name` should be written outside
1497            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1498                self.write_array_size(module, base, size)?;
1499            }
1500            TypeInner::AccelerationStructure { .. } => {
1501                write!(self.out, "RaytracingAccelerationStructure")?;
1502            }
1503            TypeInner::RayQuery { .. } => {
1504                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1505                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1506            }
1507            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1508        }
1509
1510        Ok(())
1511    }
1512
    /// Writes the complete HLSL definition of `func` under the name `name`.
    ///
    /// Output order: an optional `typedef` (only when the result type is an
    /// array), an optional `invariant`-position modifier, the return type, the
    /// parameter list, the result semantic (entry points only), and the braced
    /// body. The body begins with optional workgroup-memory zeroing and entry
    /// point argument initialization, then declares every local variable, then
    /// writes the function's statements at indentation level 1.
    ///
    /// # Notes
    /// Ends in a newline. `self.named_expressions` is cleared after the body
    /// has been written.
    fn write_function(
        &mut self,
        module: &Module,
        name: &str,
        func: &crate::Function,
        func_ctx: &back::FunctionCtx<'_>,
        info: &valid::FunctionInfo,
    ) -> BackendResult {
        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax

        // Presumably (re)computes `self.need_bake_expressions`, which `write_stmt`
        // consults on `Emit` to decide which expressions get named temporaries.
        self.update_expressions_to_bake(module, func, info);

        if let Some(ref result) = func.result {
            // Write typedef if return type is an array
            // Array dimensions are written after a declarator name, so an array
            // return type gets a `ret_<name>` alias. The alias is only used for
            // regular functions in the return-type match below.
            let array_return_type = match module.types[result.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    let array_return_type = self.namer.call(&format!("ret_{name}"));
                    write!(self.out, "typedef ")?;
                    self.write_type(module, result.ty)?;
                    write!(self.out, " {array_return_type}")?;
                    self.write_array_size(module, base, size)?;
                    writeln!(self.out, ";")?;
                    Some(array_return_type)
                }
                _ => None,
            };

            // Write modifier
            // Only an invariant `position` builtin result gets a modifier here.
            if let Some(
                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
            ) = result.binding
            {
                self.write_modifier(binding)?;
            }

            // Write return type
            match func_ctx.ty {
                back::FunctionType::Function(_) => {
                    if let Some(array_return_type) = array_return_type {
                        write!(self.out, "{array_return_type}")?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
                back::FunctionType::EntryPoint(index) => {
                    // Entry points may return a separate output struct (see
                    // `entry_point_io`) instead of `result.ty` itself.
                    if let Some(ref ep_output) =
                        self.entry_point_io.get(&(index as usize)).unwrap().output
                    {
                        write!(self.out, "{}", ep_output.ty_name)?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
            }
        } else {
            write!(self.out, "void")?;
        }

        // Write function name
        write!(self.out, " {name}(")?;

        let need_workgroup_variables_initialization =
            self.need_workgroup_variables_initialization(func_ctx, module);

        // Write function arguments for non entry point functions
        match func_ctx.ty {
            back::FunctionType::Function(handle) => {
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }

                    self.write_function_argument(module, handle, arg, index)?;
                }
            }
            back::FunctionType::EntryPoint(ep_index) => {
                // Entry points whose inputs were gathered into a struct take one
                // struct parameter; otherwise each argument is written with its
                // type, name, array size and input semantic.
                if let Some(ref ep_input) =
                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
                {
                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
                } else {
                    let stage = module.entry_points[ep_index as usize].stage;
                    for (index, arg) in func.arguments.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.write_type(module, arg.ty)?;

                        let argument_name =
                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                        write!(self.out, " {argument_name}")?;
                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                            self.write_array_size(module, base, size)?;
                        }

                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                    }
                }
                // Extra hidden parameter: `write_workgroup_variables_initialization`
                // compares it against zero so only one invocation zeroes memory.
                // A separator is needed only if some parameter was already written.
                if need_workgroup_variables_initialization {
                    if self
                        .entry_point_io
                        .get(&(ep_index as usize))
                        .unwrap()
                        .input
                        .is_some()
                        || !func.arguments.is_empty()
                    {
                        write!(self.out, ", ")?;
                    }
                    write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
                }
            }
        }
        // Ends of arguments
        write!(self.out, ")")?;

        // Write semantic if it present
        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            let stage = module.entry_points[index as usize].stage;
            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
                self.write_semantic(binding, Some((stage, Io::Output)))?;
            }
        }

        // Function body start
        writeln!(self.out)?;
        writeln!(self.out, "{{")?;

        if need_workgroup_variables_initialization {
            self.write_workgroup_variables_initialization(func_ctx, module)?;
        }

        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            self.write_ep_arguments_initialization(module, func, index)?;
        }

        // Write function local variables
        for (handle, local) in func.local_variables.iter() {
            // Write indentation (only for readability)
            write!(self.out, "{}", back::INDENT)?;

            // Write the local name
            // The leading space is important
            self.write_type(module, local.ty)?;
            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
            // Write size for array type
            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            match module.types[local.ty].inner {
                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
                TypeInner::RayQuery { .. } => {}
                _ => {
                    write!(self.out, " = ")?;
                    // Write the local initializer if needed
                    if let Some(init) = local.init {
                        self.write_expr(module, init, func_ctx)?;
                    } else {
                        // Zero initialize local variables
                        self.write_default_init(module, local.ty)?;
                    }
                }
            }
            // Finish the local with `;` and add a newline (only for readability)
            writeln!(self.out, ";")?
        }

        if !func.local_variables.is_empty() {
            writeln!(self.out)?;
        }

        // Write the function body (statement list)
        for sta in func.body.iter() {
            // The indentation should always be 1 when writing the function body
            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
        }

        writeln!(self.out, "}}")?;

        // Names recorded while writing this body must not carry over into the
        // next function written by this writer.
        self.named_expressions.clear();

        Ok(())
    }
1701
1702    fn write_function_argument(
1703        &mut self,
1704        module: &Module,
1705        handle: Handle<crate::Function>,
1706        arg: &crate::FunctionArgument,
1707        index: usize,
1708    ) -> BackendResult {
1709        // External texture arguments must be expanded into separate
1710        // arguments for each plane and the params buffer.
1711        if let TypeInner::Image {
1712            class: crate::ImageClass::External,
1713            ..
1714        } = module.types[arg.ty].inner
1715        {
1716            return self.write_function_external_texture_argument(module, handle, index);
1717        }
1718
1719        // Write argument type
1720        let arg_ty = match module.types[arg.ty].inner {
1721            // pointers in function arguments are expected and resolve to `inout`
1722            TypeInner::Pointer { base, .. } => {
1723                //TODO: can we narrow this down to just `in` when possible?
1724                write!(self.out, "inout ")?;
1725                base
1726            }
1727            _ => arg.ty,
1728        };
1729        self.write_type(module, arg_ty)?;
1730
1731        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1732
1733        // Write argument name. Space is important.
1734        write!(self.out, " {argument_name}")?;
1735        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1736            self.write_array_size(module, base, size)?;
1737        }
1738
1739        Ok(())
1740    }
1741
1742    fn write_function_external_texture_argument(
1743        &mut self,
1744        module: &Module,
1745        handle: Handle<crate::Function>,
1746        index: usize,
1747    ) -> BackendResult {
1748        let plane_names = [0, 1, 2].map(|i| {
1749            &self.names[&NameKey::ExternalTextureFunctionArgument(
1750                handle,
1751                index as u32,
1752                ExternalTextureNameKey::Plane(i),
1753            )]
1754        });
1755        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1756            handle,
1757            index as u32,
1758            ExternalTextureNameKey::Params,
1759        )];
1760        let params_ty_name =
1761            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1762        write!(
1763            self.out,
1764            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1765            plane_names[0], plane_names[1], plane_names[2],
1766        )?;
1767        Ok(())
1768    }
1769
1770    fn need_workgroup_variables_initialization(
1771        &mut self,
1772        func_ctx: &back::FunctionCtx,
1773        module: &Module,
1774    ) -> bool {
1775        self.options.zero_initialize_workgroup_memory
1776            && func_ctx.ty.is_compute_like_entry_point(module)
1777            && module.global_variables.iter().any(|(handle, var)| {
1778                !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1779            })
1780    }
1781
1782    fn write_workgroup_variables_initialization(
1783        &mut self,
1784        func_ctx: &back::FunctionCtx,
1785        module: &Module,
1786    ) -> BackendResult {
1787        let level = back::Level(1);
1788
1789        writeln!(
1790            self.out,
1791            "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1792        )?;
1793
1794        let vars = module.global_variables.iter().filter(|&(handle, var)| {
1795            !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1796        });
1797
1798        for (handle, var) in vars {
1799            let name = &self.names[&NameKey::GlobalVariable(handle)];
1800            write!(self.out, "{}{} = ", level.next(), name)?;
1801            self.write_default_init(module, var.ty)?;
1802            writeln!(self.out, ";")?;
1803        }
1804
1805        writeln!(self.out, "{level}}}")?;
1806        self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1807    }
1808
    /// Helper method used to write switches
    ///
    /// Writes a switch over `selector` with the given `cases` at `level`.
    ///
    /// If every case except the last is an empty fall-through, the switch is
    /// lowered to a `do { ... } while(false);` containing only the last case's
    /// body, because FXC mishandles such switches. Otherwise a real `switch` is
    /// emitted. Non-empty fall-through (unsupported by FXC) is emulated by
    /// duplicating the bodies of the cases it would fall into.
    ///
    /// Before the switch, `self.continue_ctx` may declare a `bool` flag. After
    /// the switch, any forwarded `continue`/`break` is re-issued when that flag
    /// is set (see the `back::continue_forward` module).
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        // Write all cases
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // See docs of `back::continue_forward` module.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // Check if there is only one body, by seeing if all except the last case are fall through
        // with empty bodies. FXC doesn't handle these switches correctly, so
        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
        // is no need to check if one of the cases would have matched.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Start the do-while
            writeln!(self.out, "{level}do {{")?;
            // Note: Expressions have no side-effects so we don't need to emit selector expression.

            // Body
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            // End do-while
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            // Start the switch
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // The new block is not only stylistic, it plays a role here:
                // We might end up having to write the same case body
                // multiple times due to FXC not supporting fallthrough.
                // Therefore, some `Expression`s written by `Statement::Emit`
                // will end up having the same name (`_expr<handle_index>`).
                // So we need to put each case in its own scope.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                // Although FXC does support a series of case clauses before
                // a block[^yes], it does not support fallthrough from a
                // non-empty case block to the next[^no]. If this case has a
                // non-empty body with a fallthrough, emulate that by
                // duplicating the bodies of all the cases it would fall
                // into as extensions of this case's own body. This makes
                // the HLSL output potentially quadratic in the size of the
                // Naga IR.
                //
                // [^yes]: ```hlsl
                // case 1:
                // case 2: do_stuff()
                // ```
                // [^no]: ```hlsl
                // case 1: do_this();
                // case 2: do_that();
                // ```
                if case.fall_through && !case.body.is_empty() {
                    let curr_len = i + 1;
                    // Index of the first later case that does not fall through;
                    // the duplicated run ends there (inclusive).
                    // NOTE(review): the `unwrap` assumes the last case never falls
                    // through (presumably enforced by validation) — confirm.
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Clear all named expressions that were previously inserted by the statements in the block
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    // Only add a `break` if the final duplicated body does not
                    // already end with a terminator statement.
                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    // Empty fall-through cases stack onto the next label; all
                    // other cases need an explicit `break` unless they already
                    // end with a terminator.
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Handle any forwarded continue statements.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
1953
1954    fn write_index(
1955        &mut self,
1956        module: &Module,
1957        index: Index,
1958        func_ctx: &back::FunctionCtx<'_>,
1959    ) -> BackendResult {
1960        match index {
1961            Index::Static(index) => {
1962                write!(self.out, "{index}")?;
1963            }
1964            Index::Expression(index) => {
1965                self.write_expr(module, index, func_ctx)?;
1966            }
1967        }
1968        Ok(())
1969    }
1970
1971    /// Helper method used to write statements
1972    ///
1973    /// # Notes
1974    /// Always adds a newline
1975    fn write_stmt(
1976        &mut self,
1977        module: &Module,
1978        stmt: &crate::Statement,
1979        func_ctx: &back::FunctionCtx<'_>,
1980        level: back::Level,
1981    ) -> BackendResult {
1982        use crate::Statement;
1983
1984        match *stmt {
1985            Statement::Emit(ref range) => {
1986                for handle in range.clone() {
1987                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
1988                    let expr_name = if ptr_class.is_some() {
1989                        // HLSL can't save a pointer-valued expression in a variable,
1990                        // but we shouldn't ever need to: they should never be named expressions,
1991                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
1992                        None
1993                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
1994                        // Front end provides names for all variables at the start of writing.
1995                        // But we write them to step by step. We need to recache them
1996                        // Otherwise, we could accidentally write variable name instead of full expression.
1997                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
1998                        Some(self.namer.call(name))
1999                    } else if self.need_bake_expressions.contains(&handle) {
2000                        Some(Baked(handle).to_string())
2001                    } else {
2002                        None
2003                    };
2004
2005                    if let Some(name) = expr_name {
2006                        write!(self.out, "{level}")?;
2007                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2008                    }
2009                }
2010            }
2011            // TODO: copy-paste from glsl-out
2012            Statement::Block(ref block) => {
2013                write!(self.out, "{level}")?;
2014                writeln!(self.out, "{{")?;
2015                for sta in block.iter() {
2016                    // Increase the indentation to help with readability
2017                    self.write_stmt(module, sta, func_ctx, level.next())?
2018                }
2019                writeln!(self.out, "{level}}}")?
2020            }
2021            // TODO: copy-paste from glsl-out
2022            Statement::If {
2023                condition,
2024                ref accept,
2025                ref reject,
2026            } => {
2027                write!(self.out, "{level}")?;
2028                write!(self.out, "if (")?;
2029                self.write_expr(module, condition, func_ctx)?;
2030                writeln!(self.out, ") {{")?;
2031
2032                let l2 = level.next();
2033                for sta in accept {
2034                    // Increase indentation to help with readability
2035                    self.write_stmt(module, sta, func_ctx, l2)?;
2036                }
2037
2038                // If there are no statements in the reject block we skip writing it
2039                // This is only for readability
2040                if !reject.is_empty() {
2041                    writeln!(self.out, "{level}}} else {{")?;
2042
2043                    for sta in reject {
2044                        // Increase indentation to help with readability
2045                        self.write_stmt(module, sta, func_ctx, l2)?;
2046                    }
2047                }
2048
2049                writeln!(self.out, "{level}}}")?
2050            }
2051            // TODO: copy-paste from glsl-out
2052            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2053            Statement::Return { value: None } => {
2054                writeln!(self.out, "{level}return;")?;
2055            }
2056            Statement::Return { value: Some(expr) } => {
2057                let base_ty_res = &func_ctx.info[expr].ty;
2058                let mut resolved = base_ty_res.inner_with(&module.types);
2059                if let TypeInner::Pointer { base, space: _ } = *resolved {
2060                    resolved = &module.types[base].inner;
2061                }
2062
2063                if let TypeInner::Struct { .. } = *resolved {
2064                    // We can safely unwrap here, since we now we working with struct
2065                    let ty = base_ty_res.handle().unwrap();
2066                    let struct_name = &self.names[&NameKey::Type(ty)];
2067                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2068                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2069                    self.write_expr(module, expr, func_ctx)?;
2070                    writeln!(self.out, ";")?;
2071
2072                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2073                    let ep_output = match func_ctx.ty {
2074                        back::FunctionType::Function(_) => None,
2075                        back::FunctionType::EntryPoint(index) => self
2076                            .entry_point_io
2077                            .get(&(index as usize))
2078                            .unwrap()
2079                            .output
2080                            .as_ref(),
2081                    };
2082                    let final_name = match ep_output {
2083                        Some(ep_output) => {
2084                            let final_name = self.namer.call(&variable_name);
2085                            write!(
2086                                self.out,
2087                                "{}const {} {} = {{ ",
2088                                level, ep_output.ty_name, final_name,
2089                            )?;
2090                            for (index, m) in ep_output.members.iter().enumerate() {
2091                                if index != 0 {
2092                                    write!(self.out, ", ")?;
2093                                }
2094                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2095                                write!(self.out, "{variable_name}.{member_name}")?;
2096                            }
2097                            writeln!(self.out, " }};")?;
2098                            final_name
2099                        }
2100                        None => variable_name,
2101                    };
2102                    writeln!(self.out, "{level}return {final_name};")?;
2103                } else {
2104                    write!(self.out, "{level}return ")?;
2105                    self.write_expr(module, expr, func_ctx)?;
2106                    writeln!(self.out, ";")?
2107                }
2108            }
2109            Statement::Store { pointer, value } => {
2110                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2111                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2112                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2113                    self.write_storage_store(
2114                        module,
2115                        var_handle,
2116                        StoreValue::Expression(value),
2117                        func_ctx,
2118                        level,
2119                        None,
2120                    )?;
2121                } else {
2122                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2123                    // See the module-level block comment in mod.rs for details.
2124                    //
2125                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2126                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2127                    enum MatrixAccess {
2128                        Direct {
2129                            base: Handle<crate::Expression>,
2130                            index: u32,
2131                        },
2132                        Struct {
2133                            columns: crate::VectorSize,
2134                            base: Handle<crate::Expression>,
2135                        },
2136                    }
2137
2138                    let get_members = |expr: Handle<crate::Expression>| {
2139                        let resolved = func_ctx.resolve_type(expr, &module.types);
2140                        match *resolved {
2141                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2142                                TypeInner::Struct { ref members, .. } => Some(members),
2143                                _ => None,
2144                            },
2145                            _ => None,
2146                        }
2147                    };
2148
2149                    write!(self.out, "{level}")?;
2150
2151                    let matrix_access_on_lhs =
2152                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2153                            |(matrix_expr, vector, scalar)| match (
2154                                func_ctx.resolve_type(matrix_expr, &module.types),
2155                                &func_ctx.expressions[matrix_expr],
2156                            ) {
2157                                (
2158                                    &TypeInner::Pointer { base: ty, .. },
2159                                    &crate::Expression::AccessIndex { base, index },
2160                                ) if matches!(
2161                                    module.types[ty].inner,
2162                                    TypeInner::Matrix {
2163                                        rows: crate::VectorSize::Bi,
2164                                        ..
2165                                    }
2166                                ) && get_members(base)
2167                                    .map(|members| members[index as usize].binding.is_none())
2168                                    == Some(true) =>
2169                                {
2170                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2171                                }
2172                                _ => {
2173                                    if let Some(MatrixType {
2174                                        columns,
2175                                        rows: crate::VectorSize::Bi,
2176                                        width: 4,
2177                                    }) = get_inner_matrix_of_struct_array_member(
2178                                        module,
2179                                        matrix_expr,
2180                                        func_ctx,
2181                                        true,
2182                                    ) {
2183                                        Some((
2184                                            MatrixAccess::Struct {
2185                                                columns,
2186                                                base: matrix_expr,
2187                                            },
2188                                            vector,
2189                                            scalar,
2190                                        ))
2191                                    } else {
2192                                        None
2193                                    }
2194                                }
2195                            },
2196                        );
2197
2198                    match matrix_access_on_lhs {
2199                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2200                            let base_ty_res = &func_ctx.info[base].ty;
2201                            let resolved = base_ty_res.inner_with(&module.types);
2202                            let ty = match *resolved {
2203                                TypeInner::Pointer { base, .. } => base,
2204                                _ => base_ty_res.handle().unwrap(),
2205                            };
2206
2207                            if let Some(Index::Static(vec_index)) = vector {
2208                                self.write_expr(module, base, func_ctx)?;
2209                                write!(
2210                                    self.out,
2211                                    ".{}_{}",
2212                                    &self.names[&NameKey::StructMember(ty, index)],
2213                                    vec_index
2214                                )?;
2215
2216                                if let Some(scalar_index) = scalar {
2217                                    write!(self.out, "[")?;
2218                                    self.write_index(module, scalar_index, func_ctx)?;
2219                                    write!(self.out, "]")?;
2220                                }
2221
2222                                write!(self.out, " = ")?;
2223                                self.write_expr(module, value, func_ctx)?;
2224                                writeln!(self.out, ";")?;
2225                            } else {
2226                                let access = WrappedStructMatrixAccess { ty, index };
2227                                match (&vector, &scalar) {
2228                                    (&Some(_), &Some(_)) => {
2229                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2230                                            access,
2231                                        )?;
2232                                    }
2233                                    (&Some(_), &None) => {
2234                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2235                                            access,
2236                                        )?;
2237                                    }
2238                                    (&None, _) => {
2239                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2240                                    }
2241                                }
2242
2243                                write!(self.out, "(")?;
2244                                self.write_expr(module, base, func_ctx)?;
2245                                write!(self.out, ", ")?;
2246                                self.write_expr(module, value, func_ctx)?;
2247
2248                                if let Some(Index::Expression(vec_index)) = vector {
2249                                    write!(self.out, ", ")?;
2250                                    self.write_expr(module, vec_index, func_ctx)?;
2251
2252                                    if let Some(scalar_index) = scalar {
2253                                        write!(self.out, ", ")?;
2254                                        self.write_index(module, scalar_index, func_ctx)?;
2255                                    }
2256                                }
2257                                writeln!(self.out, ");")?;
2258                            }
2259                        }
2260                        Some((
2261                            MatrixAccess::Struct { columns, base },
2262                            Some(Index::Expression(vec_index)),
2263                            scalar,
2264                        )) => {
2265                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2266                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2267
2268                            if scalar.is_some() {
2269                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2270                            } else {
2271                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2272                            }
2273                            write!(self.out, "(")?;
2274                            self.write_expr(module, base, func_ctx)?;
2275                            write!(self.out, ", ")?;
2276                            self.write_expr(module, vec_index, func_ctx)?;
2277
2278                            if let Some(scalar_index) = scalar {
2279                                write!(self.out, ", ")?;
2280                                self.write_index(module, scalar_index, func_ctx)?;
2281                            }
2282
2283                            write!(self.out, ", ")?;
2284                            self.write_expr(module, value, func_ctx)?;
2285
2286                            writeln!(self.out, ");")?;
2287                        }
2288                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2289                        | Some((MatrixAccess::Struct { .. }, None, _))
2290                        | None => {
2291                            self.write_expr(module, pointer, func_ctx)?;
2292                            write!(self.out, " = ")?;
2293
2294                            // We cast the RHS of this store in cases where the LHS
2295                            // is a struct member with type:
2296                            //  - matCx2 or
2297                            //  - a (possibly nested) array of matCx2's
2298                            if let Some(MatrixType {
2299                                columns,
2300                                rows: crate::VectorSize::Bi,
2301                                width: 4,
2302                            }) = get_inner_matrix_of_struct_array_member(
2303                                module, pointer, func_ctx, false,
2304                            ) {
2305                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2306                                if let TypeInner::Pointer { base, .. } = *resolved {
2307                                    resolved = &module.types[base].inner;
2308                                }
2309
2310                                write!(self.out, "(__mat{}x2", columns as u8)?;
2311                                if let TypeInner::Array { base, size, .. } = *resolved {
2312                                    self.write_array_size(module, base, size)?;
2313                                }
2314                                write!(self.out, ")")?;
2315                            }
2316
2317                            self.write_expr(module, value, func_ctx)?;
2318                            writeln!(self.out, ";")?
2319                        }
2320                    }
2321                }
2322            }
2323            Statement::Loop {
2324                ref body,
2325                ref continuing,
2326                break_if,
2327            } => {
2328                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2329                let gate_name = (!continuing.is_empty() || break_if.is_some())
2330                    .then(|| self.namer.call("loop_init"));
2331
2332                if let Some((ref decl, _)) = force_loop_bound_statements {
2333                    writeln!(self.out, "{decl}")?;
2334                }
2335                if let Some(ref gate_name) = gate_name {
2336                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2337                }
2338
2339                self.continue_ctx.enter_loop();
2340                writeln!(self.out, "{level}while(true) {{")?;
2341                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2342                    writeln!(self.out, "{break_and_inc}")?;
2343                }
2344                let l2 = level.next();
2345                if let Some(gate_name) = gate_name {
2346                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2347                    let l3 = l2.next();
2348                    for sta in continuing.iter() {
2349                        self.write_stmt(module, sta, func_ctx, l3)?;
2350                    }
2351                    if let Some(condition) = break_if {
2352                        write!(self.out, "{l3}if (")?;
2353                        self.write_expr(module, condition, func_ctx)?;
2354                        writeln!(self.out, ") {{")?;
2355                        writeln!(self.out, "{}break;", l3.next())?;
2356                        writeln!(self.out, "{l3}}}")?;
2357                    }
2358                    writeln!(self.out, "{l2}}}")?;
2359                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2360                }
2361
2362                for sta in body.iter() {
2363                    self.write_stmt(module, sta, func_ctx, l2)?;
2364                }
2365
2366                writeln!(self.out, "{level}}}")?;
2367                self.continue_ctx.exit_loop();
2368            }
2369            Statement::Break => writeln!(self.out, "{level}break;")?,
2370            Statement::Continue => {
2371                if let Some(variable) = self.continue_ctx.continue_encountered() {
2372                    writeln!(self.out, "{level}{variable} = true;")?;
2373                    writeln!(self.out, "{level}break;")?
2374                } else {
2375                    writeln!(self.out, "{level}continue;")?
2376                }
2377            }
2378            Statement::ControlBarrier(barrier) => {
2379                self.write_control_barrier(barrier, level)?;
2380            }
2381            Statement::MemoryBarrier(barrier) => {
2382                self.write_memory_barrier(barrier, level)?;
2383            }
2384            Statement::ImageStore {
2385                image,
2386                coordinate,
2387                array_index,
2388                value,
2389            } => {
2390                write!(self.out, "{level}")?;
2391                self.write_expr(module, image, func_ctx)?;
2392
2393                write!(self.out, "[")?;
2394                if let Some(index) = array_index {
2395                    // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
2396                    write!(self.out, "int3(")?;
2397                    self.write_expr(module, coordinate, func_ctx)?;
2398                    write!(self.out, ", ")?;
2399                    self.write_expr(module, index, func_ctx)?;
2400                    write!(self.out, ")")?;
2401                } else {
2402                    self.write_expr(module, coordinate, func_ctx)?;
2403                }
2404                write!(self.out, "]")?;
2405
2406                write!(self.out, " = ")?;
2407                self.write_expr(module, value, func_ctx)?;
2408                writeln!(self.out, ";")?;
2409            }
2410            Statement::Call {
2411                function,
2412                ref arguments,
2413                result,
2414            } => {
2415                write!(self.out, "{level}")?;
2416                if let Some(expr) = result {
2417                    write!(self.out, "const ")?;
2418                    let name = Baked(expr).to_string();
2419                    let expr_ty = &func_ctx.info[expr].ty;
2420                    let ty_inner = match *expr_ty {
2421                        proc::TypeResolution::Handle(handle) => {
2422                            self.write_type(module, handle)?;
2423                            &module.types[handle].inner
2424                        }
2425                        proc::TypeResolution::Value(ref value) => {
2426                            self.write_value_type(module, value)?;
2427                            value
2428                        }
2429                    };
2430                    write!(self.out, " {name}")?;
2431                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2432                        self.write_array_size(module, base, size)?;
2433                    }
2434                    write!(self.out, " = ")?;
2435                    self.named_expressions.insert(expr, name);
2436                }
2437                let func_name = &self.names[&NameKey::Function(function)];
2438                write!(self.out, "{func_name}(")?;
2439                for (index, argument) in arguments.iter().enumerate() {
2440                    if index != 0 {
2441                        write!(self.out, ", ")?;
2442                    }
2443                    self.write_expr(module, *argument, func_ctx)?;
2444                }
2445                writeln!(self.out, ");")?
2446            }
2447            Statement::Atomic {
2448                pointer,
2449                ref fun,
2450                value,
2451                result,
2452            } => {
2453                write!(self.out, "{level}")?;
2454                let res_var_info = if let Some(res_handle) = result {
2455                    let name = Baked(res_handle).to_string();
2456                    match func_ctx.info[res_handle].ty {
2457                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2458                        proc::TypeResolution::Value(ref value) => {
2459                            self.write_value_type(module, value)?
2460                        }
2461                    };
2462                    write!(self.out, " {name}; ")?;
2463                    self.named_expressions.insert(res_handle, name.clone());
2464                    Some((res_handle, name))
2465                } else {
2466                    None
2467                };
2468                let pointer_space = func_ctx
2469                    .resolve_type(pointer, &module.types)
2470                    .pointer_space()
2471                    .unwrap();
2472                let fun_str = fun.to_hlsl_suffix();
2473                let compare_expr = match *fun {
2474                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2475                    _ => None,
2476                };
2477                match pointer_space {
2478                    crate::AddressSpace::WorkGroup => {
2479                        write!(self.out, "Interlocked{fun_str}(")?;
2480                        self.write_expr(module, pointer, func_ctx)?;
2481                        self.emit_hlsl_atomic_tail(
2482                            module,
2483                            func_ctx,
2484                            fun,
2485                            compare_expr,
2486                            value,
2487                            &res_var_info,
2488                        )?;
2489                    }
2490                    crate::AddressSpace::Storage { .. } => {
2491                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2492                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2493                        let width = match func_ctx.resolve_type(value, &module.types) {
2494                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2495                            _ => "",
2496                        };
2497                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2498                        let chain = mem::take(&mut self.temp_access_chain);
2499                        self.write_storage_address(module, &chain, func_ctx)?;
2500                        self.temp_access_chain = chain;
2501                        self.emit_hlsl_atomic_tail(
2502                            module,
2503                            func_ctx,
2504                            fun,
2505                            compare_expr,
2506                            value,
2507                            &res_var_info,
2508                        )?;
2509                    }
2510                    ref other => {
2511                        return Err(Error::Custom(format!(
2512                            "invalid address space {other:?} for atomic statement"
2513                        )))
2514                    }
2515                }
2516                if let Some(cmp) = compare_expr {
2517                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2518                        write!(
2519                            self.out,
2520                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2521                        )?;
2522                        self.write_expr(module, cmp, func_ctx)?;
2523                        writeln!(self.out, ");")?;
2524                    }
2525                }
2526            }
2527            Statement::ImageAtomic {
2528                image,
2529                coordinate,
2530                array_index,
2531                fun,
2532                value,
2533            } => {
2534                write!(self.out, "{level}")?;
2535
2536                let fun_str = fun.to_hlsl_suffix();
2537                write!(self.out, "Interlocked{fun_str}(")?;
2538                self.write_expr(module, image, func_ctx)?;
2539                write!(self.out, "[")?;
2540                self.write_texture_coordinates(
2541                    "int",
2542                    coordinate,
2543                    array_index,
2544                    None,
2545                    module,
2546                    func_ctx,
2547                )?;
2548                write!(self.out, "],")?;
2549
2550                self.write_expr(module, value, func_ctx)?;
2551                writeln!(self.out, ");")?;
2552            }
2553            Statement::WorkGroupUniformLoad { pointer, result } => {
2554                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2555                write!(self.out, "{level}")?;
2556                let name = Baked(result).to_string();
2557                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2558
2559                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2560            }
2561            Statement::Switch {
2562                selector,
2563                ref cases,
2564            } => {
2565                self.write_switch(module, func_ctx, level, selector, cases)?;
2566            }
2567            Statement::RayQuery { query, ref fun } => match *fun {
2568                RayQueryFunction::Initialize {
2569                    acceleration_structure,
2570                    descriptor,
2571                } => {
2572                    write!(self.out, "{level}")?;
2573                    self.write_expr(module, query, func_ctx)?;
2574                    write!(self.out, ".TraceRayInline(")?;
2575                    self.write_expr(module, acceleration_structure, func_ctx)?;
2576                    write!(self.out, ", ")?;
2577                    self.write_expr(module, descriptor, func_ctx)?;
2578                    write!(self.out, ".flags, ")?;
2579                    self.write_expr(module, descriptor, func_ctx)?;
2580                    write!(self.out, ".cull_mask, ")?;
2581                    write!(self.out, "RayDescFromRayDesc_(")?;
2582                    self.write_expr(module, descriptor, func_ctx)?;
2583                    writeln!(self.out, "));")?;
2584                }
2585                RayQueryFunction::Proceed { result } => {
2586                    write!(self.out, "{level}")?;
2587                    let name = Baked(result).to_string();
2588                    write!(self.out, "const bool {name} = ")?;
2589                    self.named_expressions.insert(result, name);
2590                    self.write_expr(module, query, func_ctx)?;
2591                    writeln!(self.out, ".Proceed();")?;
2592                }
2593                RayQueryFunction::GenerateIntersection { hit_t } => {
2594                    write!(self.out, "{level}")?;
2595                    self.write_expr(module, query, func_ctx)?;
2596                    write!(self.out, ".CommitProceduralPrimitiveHit(")?;
2597                    self.write_expr(module, hit_t, func_ctx)?;
2598                    writeln!(self.out, ");")?;
2599                }
2600                RayQueryFunction::ConfirmIntersection => {
2601                    write!(self.out, "{level}")?;
2602                    self.write_expr(module, query, func_ctx)?;
2603                    writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
2604                }
2605                RayQueryFunction::Terminate => {
2606                    write!(self.out, "{level}")?;
2607                    self.write_expr(module, query, func_ctx)?;
2608                    writeln!(self.out, ".Abort();")?;
2609                }
2610            },
2611            Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs {
2612                vertex_count,
2613                primitive_count,
2614            }) => {
2615                write!(self.out, "{level}SetMeshOutputCounts(")?;
2616                self.write_expr(module, vertex_count, func_ctx)?;
2617                write!(self.out, ", ")?;
2618                self.write_expr(module, primitive_count, func_ctx)?;
2619                write!(self.out, ");")?;
2620            }
2621            Statement::MeshFunction(
2622                crate::MeshFunction::SetVertex { .. } | crate::MeshFunction::SetPrimitive { .. },
2623            ) => unimplemented!(),
2624            Statement::SubgroupBallot { result, predicate } => {
2625                write!(self.out, "{level}")?;
2626                let name = Baked(result).to_string();
2627                write!(self.out, "const uint4 {name} = ")?;
2628                self.named_expressions.insert(result, name);
2629
2630                write!(self.out, "WaveActiveBallot(")?;
2631                match predicate {
2632                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2633                    None => write!(self.out, "true")?,
2634                }
2635                writeln!(self.out, ");")?;
2636            }
2637            Statement::SubgroupCollectiveOperation {
2638                op,
2639                collective_op,
2640                argument,
2641                result,
2642            } => {
2643                write!(self.out, "{level}")?;
2644                write!(self.out, "const ")?;
2645                let name = Baked(result).to_string();
2646                match func_ctx.info[result].ty {
2647                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2648                    proc::TypeResolution::Value(ref value) => {
2649                        self.write_value_type(module, value)?
2650                    }
2651                };
2652                write!(self.out, " {name} = ")?;
2653                self.named_expressions.insert(result, name);
2654
2655                match (collective_op, op) {
2656                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2657                        write!(self.out, "WaveActiveAllTrue(")?
2658                    }
2659                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2660                        write!(self.out, "WaveActiveAnyTrue(")?
2661                    }
2662                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2663                        write!(self.out, "WaveActiveSum(")?
2664                    }
2665                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2666                        write!(self.out, "WaveActiveProduct(")?
2667                    }
2668                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2669                        write!(self.out, "WaveActiveMax(")?
2670                    }
2671                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2672                        write!(self.out, "WaveActiveMin(")?
2673                    }
2674                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2675                        write!(self.out, "WaveActiveBitAnd(")?
2676                    }
2677                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2678                        write!(self.out, "WaveActiveBitOr(")?
2679                    }
2680                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2681                        write!(self.out, "WaveActiveBitXor(")?
2682                    }
2683                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2684                        write!(self.out, "WavePrefixSum(")?
2685                    }
2686                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2687                        write!(self.out, "WavePrefixProduct(")?
2688                    }
2689                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2690                        self.write_expr(module, argument, func_ctx)?;
2691                        write!(self.out, " + WavePrefixSum(")?;
2692                    }
2693                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2694                        self.write_expr(module, argument, func_ctx)?;
2695                        write!(self.out, " * WavePrefixProduct(")?;
2696                    }
2697                    _ => unimplemented!(),
2698                }
2699                self.write_expr(module, argument, func_ctx)?;
2700                writeln!(self.out, ");")?;
2701            }
2702            Statement::SubgroupGather {
2703                mode,
2704                argument,
2705                result,
2706            } => {
2707                write!(self.out, "{level}")?;
2708                write!(self.out, "const ")?;
2709                let name = Baked(result).to_string();
2710                match func_ctx.info[result].ty {
2711                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2712                    proc::TypeResolution::Value(ref value) => {
2713                        self.write_value_type(module, value)?
2714                    }
2715                };
2716                write!(self.out, " {name} = ")?;
2717                self.named_expressions.insert(result, name);
2718                match mode {
2719                    crate::GatherMode::BroadcastFirst => {
2720                        write!(self.out, "WaveReadLaneFirst(")?;
2721                        self.write_expr(module, argument, func_ctx)?;
2722                    }
2723                    crate::GatherMode::QuadBroadcast(index) => {
2724                        write!(self.out, "QuadReadLaneAt(")?;
2725                        self.write_expr(module, argument, func_ctx)?;
2726                        write!(self.out, ", ")?;
2727                        self.write_expr(module, index, func_ctx)?;
2728                    }
2729                    crate::GatherMode::QuadSwap(direction) => {
2730                        match direction {
2731                            crate::Direction::X => {
2732                                write!(self.out, "QuadReadAcrossX(")?;
2733                            }
2734                            crate::Direction::Y => {
2735                                write!(self.out, "QuadReadAcrossY(")?;
2736                            }
2737                            crate::Direction::Diagonal => {
2738                                write!(self.out, "QuadReadAcrossDiagonal(")?;
2739                            }
2740                        }
2741                        self.write_expr(module, argument, func_ctx)?;
2742                    }
2743                    _ => {
2744                        write!(self.out, "WaveReadLaneAt(")?;
2745                        self.write_expr(module, argument, func_ctx)?;
2746                        write!(self.out, ", ")?;
2747                        match mode {
2748                            crate::GatherMode::BroadcastFirst => unreachable!(),
2749                            crate::GatherMode::Broadcast(index)
2750                            | crate::GatherMode::Shuffle(index) => {
2751                                self.write_expr(module, index, func_ctx)?;
2752                            }
2753                            crate::GatherMode::ShuffleDown(index) => {
2754                                write!(self.out, "WaveGetLaneIndex() + ")?;
2755                                self.write_expr(module, index, func_ctx)?;
2756                            }
2757                            crate::GatherMode::ShuffleUp(index) => {
2758                                write!(self.out, "WaveGetLaneIndex() - ")?;
2759                                self.write_expr(module, index, func_ctx)?;
2760                            }
2761                            crate::GatherMode::ShuffleXor(index) => {
2762                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
2763                                self.write_expr(module, index, func_ctx)?;
2764                            }
2765                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2766                            crate::GatherMode::QuadSwap(_) => unreachable!(),
2767                        }
2768                    }
2769                }
2770                writeln!(self.out, ");")?;
2771            }
2772        }
2773
2774        Ok(())
2775    }
2776
2777    fn write_const_expression(
2778        &mut self,
2779        module: &Module,
2780        expr: Handle<crate::Expression>,
2781        arena: &crate::Arena<crate::Expression>,
2782    ) -> BackendResult {
2783        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2784            writer.write_const_expression(module, expr, arena)
2785        })
2786    }
2787
2788    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2789        match literal {
2790            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2791            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2792            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2793            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2794            // `-2147483648` is parsed by some compilers as unary negation of
2795            // positive 2147483648, which is too large for an int, causing
2796            // issues for some compilers. Neither DXC nor FXC appear to have
2797            // this problem, but this is not specified and could change. We
2798            // therefore use `-2147483647 - 1` as a precaution.
2799            crate::Literal::I32(value) if value == i32::MIN => {
2800                write!(self.out, "int({} - 1)", value + 1)?
2801            }
2802            // HLSL has no suffix for explicit i32 literals, but not using any suffix
2803            // makes the type ambiguous which prevents overload resolution from
2804            // working. So we explicitly use the int() constructor syntax.
2805            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2806            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2807            // I64 version of the minimum I32 value issue described above.
2808            crate::Literal::I64(value) if value == i64::MIN => {
2809                write!(self.out, "({}L - 1L)", value + 1)?;
2810            }
2811            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2812            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2813            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2814                return Err(Error::Custom(
2815                    "Abstract types should not appear in IR presented to backends".into(),
2816                ));
2817            }
2818        }
2819        Ok(())
2820    }
2821
2822    fn write_possibly_const_expression<E>(
2823        &mut self,
2824        module: &Module,
2825        expr: Handle<crate::Expression>,
2826        expressions: &crate::Arena<crate::Expression>,
2827        write_expression: E,
2828    ) -> BackendResult
2829    where
2830        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2831    {
2832        use crate::Expression;
2833
2834        match expressions[expr] {
2835            Expression::Literal(literal) => {
2836                self.write_literal(literal)?;
2837            }
2838            Expression::Constant(handle) => {
2839                let constant = &module.constants[handle];
2840                if constant.name.is_some() {
2841                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2842                } else {
2843                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
2844                }
2845            }
2846            Expression::ZeroValue(ty) => {
2847                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2848                write!(self.out, "()")?;
2849            }
2850            Expression::Compose { ty, ref components } => {
2851                match module.types[ty].inner {
2852                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2853                        self.write_wrapped_constructor_function_name(
2854                            module,
2855                            WrappedConstructor { ty },
2856                        )?;
2857                    }
2858                    _ => {
2859                        self.write_type(module, ty)?;
2860                    }
2861                };
2862                write!(self.out, "(")?;
2863                for (index, component) in components.iter().enumerate() {
2864                    if index != 0 {
2865                        write!(self.out, ", ")?;
2866                    }
2867                    write_expression(self, *component)?;
2868                }
2869                write!(self.out, ")")?;
2870            }
2871            Expression::Splat { size, value } => {
2872                // hlsl is not supported one value constructor
2873                // if we write, for example, int4(0), dxc returns error:
2874                // error: too few elements in vector initialization (expected 4 elements, have 1)
2875                let number_of_components = match size {
2876                    crate::VectorSize::Bi => "xx",
2877                    crate::VectorSize::Tri => "xxx",
2878                    crate::VectorSize::Quad => "xxxx",
2879                };
2880                write!(self.out, "(")?;
2881                write_expression(self, value)?;
2882                write!(self.out, ").{number_of_components}")?
2883            }
2884            _ => {
2885                return Err(Error::Override);
2886            }
2887        }
2888
2889        Ok(())
2890    }
2891
2892    /// Helper method to write expressions
2893    ///
2894    /// # Notes
2895    /// Doesn't add any newlines or leading/trailing spaces
2896    pub(super) fn write_expr(
2897        &mut self,
2898        module: &Module,
2899        expr: Handle<crate::Expression>,
2900        func_ctx: &back::FunctionCtx<'_>,
2901    ) -> BackendResult {
2902        use crate::Expression;
2903
2904        // Handle the special semantics of vertex_index/instance_index
2905        let ff_input = if self.options.special_constants_binding.is_some() {
2906            func_ctx.is_fixed_function_input(expr, module)
2907        } else {
2908            None
2909        };
2910        let closing_bracket = match ff_input {
2911            Some(crate::BuiltIn::VertexIndex) => {
2912                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2913                ")"
2914            }
2915            Some(crate::BuiltIn::InstanceIndex) => {
2916                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2917                ")"
2918            }
2919            Some(crate::BuiltIn::NumWorkGroups) => {
2920                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
2921                // in compute shaders the special constants contain the number
2922                // of workgroups, which we are using here.
2923                write!(
2924                    self.out,
2925                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2926                )?;
2927                return Ok(());
2928            }
2929            _ => "",
2930        };
2931
2932        if let Some(name) = self.named_expressions.get(&expr) {
2933            write!(self.out, "{name}{closing_bracket}")?;
2934            return Ok(());
2935        }
2936
2937        let expression = &func_ctx.expressions[expr];
2938
2939        match *expression {
2940            Expression::Literal(_)
2941            | Expression::Constant(_)
2942            | Expression::ZeroValue(_)
2943            | Expression::Compose { .. }
2944            | Expression::Splat { .. } => {
2945                self.write_possibly_const_expression(
2946                    module,
2947                    expr,
2948                    func_ctx.expressions,
2949                    |writer, expr| writer.write_expr(module, expr, func_ctx),
2950                )?;
2951            }
2952            Expression::Override(_) => return Err(Error::Override),
2953            // Avoid undefined behaviour for addition, subtraction, and
2954            // multiplication of signed integers by casting operands to
2955            // unsigned, performing the operation, then casting the result back
2956            // to signed.
2957            // TODO(#7109): This relies on the asint()/asuint() functions which only work
2958            // for 32-bit types, so we must find another solution for different bit widths.
2959            Expression::Binary {
2960                op:
2961                    op @ crate::BinaryOperator::Add
2962                    | op @ crate::BinaryOperator::Subtract
2963                    | op @ crate::BinaryOperator::Multiply,
2964                left,
2965                right,
2966            } if matches!(
2967                func_ctx.resolve_type(expr, &module.types).scalar(),
2968                Some(Scalar::I32)
2969            ) =>
2970            {
2971                write!(self.out, "asint(asuint(",)?;
2972                self.write_expr(module, left, func_ctx)?;
2973                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
2974                self.write_expr(module, right, func_ctx)?;
2975                write!(self.out, "))")?;
2976            }
2977            // All of the multiplication can be expressed as `mul`,
2978            // except vector * vector, which needs to use the "*" operator.
2979            Expression::Binary {
2980                op: crate::BinaryOperator::Multiply,
2981                left,
2982                right,
2983            } if func_ctx.resolve_type(left, &module.types).is_matrix()
2984                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2985            {
2986                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
2987                write!(self.out, "mul(")?;
2988                self.write_expr(module, right, func_ctx)?;
2989                write!(self.out, ", ")?;
2990                self.write_expr(module, left, func_ctx)?;
2991                write!(self.out, ")")?;
2992            }
2993
2994            // WGSL says that floating-point division by zero should return
2995            // infinity. Microsoft's Direct3D 11 functional specification
2996            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
2997            // says:
2998            //
2999            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3000            //
3001            // which is what we want. The DXIL specification for the FDiv
3002            // instruction corroborates this:
3003            //
3004            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3005            Expression::Binary {
3006                op: crate::BinaryOperator::Divide,
3007                left,
3008                right,
3009            } if matches!(
3010                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3011                Some(ScalarKind::Sint | ScalarKind::Uint)
3012            ) =>
3013            {
3014                write!(self.out, "{DIV_FUNCTION}(")?;
3015                self.write_expr(module, left, func_ctx)?;
3016                write!(self.out, ", ")?;
3017                self.write_expr(module, right, func_ctx)?;
3018                write!(self.out, ")")?;
3019            }
3020
3021            Expression::Binary {
3022                op: crate::BinaryOperator::Modulo,
3023                left,
3024                right,
3025            } if matches!(
3026                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3027                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3028            ) =>
3029            {
3030                write!(self.out, "{MOD_FUNCTION}(")?;
3031                self.write_expr(module, left, func_ctx)?;
3032                write!(self.out, ", ")?;
3033                self.write_expr(module, right, func_ctx)?;
3034                write!(self.out, ")")?;
3035            }
3036
3037            Expression::Binary { op, left, right } => {
3038                write!(self.out, "(")?;
3039                self.write_expr(module, left, func_ctx)?;
3040                write!(self.out, " {} ", back::binary_operation_str(op))?;
3041                self.write_expr(module, right, func_ctx)?;
3042                write!(self.out, ")")?;
3043            }
3044            Expression::Access { base, index } => {
3045                if let Some(crate::AddressSpace::Storage { .. }) =
3046                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3047                {
3048                    // do nothing, the chain is written on `Load`/`Store`
3049                } else {
3050                    // We use the function __get_col_of_matCx2 here in cases
3051                    // where `base`s type resolves to a matCx2 and is part of a
3052                    // struct member with type of (possibly nested) array of matCx2's.
3053                    //
3054                    // Note that this only works for `Load`s and we handle
3055                    // `Store`s differently in `Statement::Store`.
3056                    if let Some(MatrixType {
3057                        columns,
3058                        rows: crate::VectorSize::Bi,
3059                        width: 4,
3060                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3061                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3062                    {
3063                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3064                        self.write_expr(module, base, func_ctx)?;
3065                        write!(self.out, ", ")?;
3066                        self.write_expr(module, index, func_ctx)?;
3067                        write!(self.out, ")")?;
3068                        return Ok(());
3069                    }
3070
3071                    let resolved = func_ctx.resolve_type(base, &module.types);
3072
3073                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3074                        TypeInner::BindingArray { .. } => {
3075                            let uniformity = &func_ctx.info[index].uniformity;
3076
3077                            (true, uniformity.non_uniform_result.is_some())
3078                        }
3079                        _ => (false, false),
3080                    };
3081
3082                    self.write_expr(module, base, func_ctx)?;
3083
3084                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3085                        module, func_ctx, base, resolved,
3086                    );
3087
3088                    if let Some(ref info) = array_sampler_info {
3089                        write!(self.out, "{}[", info.sampler_heap_name)?;
3090                    } else {
3091                        write!(self.out, "[")?;
3092                    }
3093
3094                    let needs_bound_check = self.options.restrict_indexing
3095                        && !indexing_binding_array
3096                        && match resolved.pointer_space() {
3097                            Some(
3098                                crate::AddressSpace::Function
3099                                | crate::AddressSpace::Private
3100                                | crate::AddressSpace::WorkGroup
3101                                | crate::AddressSpace::PushConstant
3102                                | crate::AddressSpace::TaskPayload,
3103                            )
3104                            | None => true,
3105                            Some(crate::AddressSpace::Uniform) => {
3106                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3107                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3108                                let bind_target = self
3109                                    .options
3110                                    .resolve_resource_binding(
3111                                        module.global_variables[var_handle]
3112                                            .binding
3113                                            .as_ref()
3114                                            .unwrap(),
3115                                    )
3116                                    .unwrap();
3117                                bind_target.restrict_indexing
3118                            }
3119                            Some(
3120                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3121                            ) => unreachable!(),
3122                        };
3123                    // Decide whether this index needs to be clamped to fall within range.
3124                    let restriction_needed = if needs_bound_check {
3125                        index::access_needs_check(
3126                            base,
3127                            index::GuardedIndex::Expression(index),
3128                            module,
3129                            func_ctx.expressions,
3130                            func_ctx.info,
3131                        )
3132                    } else {
3133                        None
3134                    };
3135                    if let Some(limit) = restriction_needed {
3136                        write!(self.out, "min(uint(")?;
3137                        self.write_expr(module, index, func_ctx)?;
3138                        write!(self.out, "), ")?;
3139                        match limit {
3140                            index::IndexableLength::Known(limit) => {
3141                                write!(self.out, "{}u", limit - 1)?;
3142                            }
3143                            index::IndexableLength::Dynamic => unreachable!(),
3144                        }
3145                        write!(self.out, ")")?;
3146                    } else {
3147                        if non_uniform_qualifier {
3148                            write!(self.out, "NonUniformResourceIndex(")?;
3149                        }
3150                        if let Some(ref info) = array_sampler_info {
3151                            write!(
3152                                self.out,
3153                                "{}[{} + ",
3154                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3155                            )?;
3156                        }
3157                        self.write_expr(module, index, func_ctx)?;
3158                        if array_sampler_info.is_some() {
3159                            write!(self.out, "]")?;
3160                        }
3161                        if non_uniform_qualifier {
3162                            write!(self.out, ")")?;
3163                        }
3164                    }
3165
3166                    write!(self.out, "]")?;
3167                }
3168            }
3169            Expression::AccessIndex { base, index } => {
3170                if let Some(crate::AddressSpace::Storage { .. }) =
3171                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3172                {
3173                    // do nothing, the chain is written on `Load`/`Store`
3174                } else {
3175                    // See if we need to write the matrix column access in a
3176                    // special way since the type of `base` is our special
3177                    // __matCx2 struct.
3178                    if let Some(MatrixType {
3179                        rows: crate::VectorSize::Bi,
3180                        width: 4,
3181                        ..
3182                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3183                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3184                    {
3185                        self.write_expr(module, base, func_ctx)?;
3186                        write!(self.out, "._{index}")?;
3187                        return Ok(());
3188                    }
3189
3190                    let base_ty_res = &func_ctx.info[base].ty;
3191                    let mut resolved = base_ty_res.inner_with(&module.types);
3192                    let base_ty_handle = match *resolved {
3193                        TypeInner::Pointer { base, .. } => {
3194                            resolved = &module.types[base].inner;
3195                            Some(base)
3196                        }
3197                        _ => base_ty_res.handle(),
3198                    };
3199
3200                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3201                    // See the module-level block comment in mod.rs for details.
3202                    //
3203                    // We handle matrix reconstruction here for Loads.
3204                    // Stores are handled directly by `Statement::Store`.
3205                    if let TypeInner::Struct { ref members, .. } = *resolved {
3206                        let member = &members[index as usize];
3207
3208                        match module.types[member.ty].inner {
3209                            TypeInner::Matrix {
3210                                rows: crate::VectorSize::Bi,
3211                                ..
3212                            } if member.binding.is_none() => {
3213                                let ty = base_ty_handle.unwrap();
3214                                self.write_wrapped_struct_matrix_get_function_name(
3215                                    WrappedStructMatrixAccess { ty, index },
3216                                )?;
3217                                write!(self.out, "(")?;
3218                                self.write_expr(module, base, func_ctx)?;
3219                                write!(self.out, ")")?;
3220                                return Ok(());
3221                            }
3222                            _ => {}
3223                        }
3224                    }
3225
3226                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3227                        module, func_ctx, base, resolved,
3228                    );
3229
3230                    if let Some(ref info) = array_sampler_info {
3231                        write!(
3232                            self.out,
3233                            "{}[{}",
3234                            info.sampler_heap_name, info.sampler_index_buffer_name
3235                        )?;
3236                    }
3237
3238                    self.write_expr(module, base, func_ctx)?;
3239
3240                    match *resolved {
3241                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3242                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3243                        // it and generates completely nonsensical DXBC.
3244                        //
3245                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3246                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3247                            // Write vector access as a swizzle
3248                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3249                        }
3250                        TypeInner::Matrix { .. }
3251                        | TypeInner::Array { .. }
3252                        | TypeInner::BindingArray { .. } => {
3253                            if let Some(ref info) = array_sampler_info {
3254                                write!(
3255                                    self.out,
3256                                    "[{} + {index}]",
3257                                    info.binding_array_base_index_name
3258                                )?;
3259                            } else {
3260                                write!(self.out, "[{index}]")?;
3261                            }
3262                        }
3263                        TypeInner::Struct { .. } => {
3264                            // This will never panic in case the type is a `Struct`, this is not true
3265                            // for other types so we can only check while inside this match arm
3266                            let ty = base_ty_handle.unwrap();
3267
3268                            write!(
3269                                self.out,
3270                                ".{}",
3271                                &self.names[&NameKey::StructMember(ty, index)]
3272                            )?
3273                        }
3274                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3275                    }
3276
3277                    if array_sampler_info.is_some() {
3278                        write!(self.out, "]")?;
3279                    }
3280                }
3281            }
3282            Expression::FunctionArgument(pos) => {
3283                let ty = func_ctx.resolve_type(expr, &module.types);
3284
3285                // We know that any external texture function argument has been expanded into
3286                // separate consecutive arguments for each plane and the parameters buffer. And we
3287                // also know that external textures can only ever be used as an argument to another
3288                // function. Therefore we can simply emit each of the expanded arguments in a
3289                // consecutive comma-separated list.
3290                if let TypeInner::Image {
3291                    class: crate::ImageClass::External,
3292                    ..
3293                } = *ty
3294                {
3295                    let plane_names = [0, 1, 2].map(|i| {
3296                        &self.names[&func_ctx
3297                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3298                    });
3299                    let params_name = &self.names[&func_ctx
3300                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3301                    write!(
3302                        self.out,
3303                        "{}, {}, {}, {}",
3304                        plane_names[0], plane_names[1], plane_names[2], params_name
3305                    )?;
3306                } else {
3307                    let key = func_ctx.argument_key(pos);
3308                    let name = &self.names[&key];
3309                    write!(self.out, "{name}")?;
3310                }
3311            }
3312            Expression::ImageSample {
3313                coordinate,
3314                image,
3315                sampler,
3316                clamp_to_edge: true,
3317                gather: None,
3318                array_index: None,
3319                offset: None,
3320                level: crate::SampleLevel::Zero,
3321                depth_ref: None,
3322            } => {
3323                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3324                self.write_expr(module, image, func_ctx)?;
3325                write!(self.out, ", ")?;
3326                self.write_expr(module, sampler, func_ctx)?;
3327                write!(self.out, ", ")?;
3328                self.write_expr(module, coordinate, func_ctx)?;
3329                write!(self.out, ")")?;
3330            }
3331            Expression::ImageSample {
3332                image,
3333                sampler,
3334                gather,
3335                coordinate,
3336                array_index,
3337                offset,
3338                level,
3339                depth_ref,
3340                clamp_to_edge,
3341            } => {
3342                if clamp_to_edge {
3343                    return Err(Error::Custom(
3344                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3345                    ));
3346                }
3347
3348                use crate::SampleLevel as Sl;
3349                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3350
3351                let (base_str, component_str) = match gather {
3352                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3353                    None => ("Sample", ""),
3354                };
3355                let cmp_str = match depth_ref {
3356                    Some(_) => "Cmp",
3357                    None => "",
3358                };
3359                let level_str = match level {
3360                    Sl::Zero if gather.is_none() => "LevelZero",
3361                    Sl::Auto | Sl::Zero => "",
3362                    Sl::Exact(_) => "Level",
3363                    Sl::Bias(_) => "Bias",
3364                    Sl::Gradient { .. } => "Grad",
3365                };
3366
3367                self.write_expr(module, image, func_ctx)?;
3368                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3369                self.write_expr(module, sampler, func_ctx)?;
3370                write!(self.out, ", ")?;
3371                self.write_texture_coordinates(
3372                    "float",
3373                    coordinate,
3374                    array_index,
3375                    None,
3376                    module,
3377                    func_ctx,
3378                )?;
3379
3380                if let Some(depth_ref) = depth_ref {
3381                    write!(self.out, ", ")?;
3382                    self.write_expr(module, depth_ref, func_ctx)?;
3383                }
3384
3385                match level {
3386                    Sl::Auto | Sl::Zero => {}
3387                    Sl::Exact(expr) => {
3388                        write!(self.out, ", ")?;
3389                        self.write_expr(module, expr, func_ctx)?;
3390                    }
3391                    Sl::Bias(expr) => {
3392                        write!(self.out, ", ")?;
3393                        self.write_expr(module, expr, func_ctx)?;
3394                    }
3395                    Sl::Gradient { x, y } => {
3396                        write!(self.out, ", ")?;
3397                        self.write_expr(module, x, func_ctx)?;
3398                        write!(self.out, ", ")?;
3399                        self.write_expr(module, y, func_ctx)?;
3400                    }
3401                }
3402
3403                if let Some(offset) = offset {
3404                    write!(self.out, ", ")?;
3405                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3406                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3407                    write!(self.out, ")")?;
3408                }
3409
3410                write!(self.out, ")")?;
3411            }
3412            Expression::ImageQuery { image, query } => {
3413                // use wrapped image query function
3414                if let TypeInner::Image {
3415                    dim,
3416                    arrayed,
3417                    class,
3418                } = *func_ctx.resolve_type(image, &module.types)
3419                {
3420                    let wrapped_image_query = WrappedImageQuery {
3421                        dim,
3422                        arrayed,
3423                        class,
3424                        query: query.into(),
3425                    };
3426
3427                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3428                    write!(self.out, "(")?;
3429                    // Image always first param
3430                    self.write_expr(module, image, func_ctx)?;
3431                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3432                        write!(self.out, ", ")?;
3433                        self.write_expr(module, level, func_ctx)?;
3434                    }
3435                    write!(self.out, ")")?;
3436                }
3437            }
3438            Expression::ImageLoad {
3439                image,
3440                coordinate,
3441                array_index,
3442                sample,
3443                level,
3444            } => self.write_image_load(
3445                &module,
3446                expr,
3447                func_ctx,
3448                image,
3449                coordinate,
3450                array_index,
3451                sample,
3452                level,
3453            )?,
3454            Expression::GlobalVariable(handle) => {
3455                let global_variable = &module.global_variables[handle];
3456                let ty = &module.types[global_variable.ty].inner;
3457
3458                // In the case of binding arrays of samplers, we need to not write anything
3459                // as the we are in the wrong position to fully write the expression.
3460                //
3461                // The entire writing is done by AccessIndex.
3462                let is_binding_array_of_samplers = match *ty {
3463                    TypeInner::BindingArray { base, .. } => {
3464                        let base_ty = &module.types[base].inner;
3465                        matches!(*base_ty, TypeInner::Sampler { .. })
3466                    }
3467                    _ => false,
3468                };
3469
3470                let is_storage_space =
3471                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3472
3473                // Our external texture global variable has been expanded into multiple
3474                // global variables, one for each plane and the parameters buffer.
3475                // External textures can only ever be used as arguments to a function
3476                // call, and we know that an external texture argument to any function
3477                // will have been expanded to separate consecutive arguments for each
3478                // plane and the parameters buffer. Therefore we can simply emit each of
3479                // the expanded global variables in a consecutive comma-separated list.
3480                if let TypeInner::Image {
3481                    class: crate::ImageClass::External,
3482                    ..
3483                } = *ty
3484                {
3485                    let plane_names = [0, 1, 2].map(|i| {
3486                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3487                            handle,
3488                            ExternalTextureNameKey::Plane(i),
3489                        )]
3490                    });
3491                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3492                        handle,
3493                        ExternalTextureNameKey::Params,
3494                    )];
3495                    write!(
3496                        self.out,
3497                        "{}, {}, {}, {}",
3498                        plane_names[0], plane_names[1], plane_names[2], params_name
3499                    )?;
3500                } else if !is_binding_array_of_samplers && !is_storage_space {
3501                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3502                    write!(self.out, "{name}")?;
3503                }
3504            }
3505            Expression::LocalVariable(handle) => {
3506                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3507            }
3508            Expression::Load { pointer } => {
3509                match func_ctx
3510                    .resolve_type(pointer, &module.types)
3511                    .pointer_space()
3512                {
3513                    Some(crate::AddressSpace::Storage { .. }) => {
3514                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3515                        let result_ty = func_ctx.info[expr].ty.clone();
3516                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3517                    }
3518                    _ => {
3519                        let mut close_paren = false;
3520
3521                        // We cast the value loaded to a native HLSL floatCx2
3522                        // in cases where it is of type:
3523                        //  - __matCx2 or
3524                        //  - a (possibly nested) array of __matCx2's
3525                        if let Some(MatrixType {
3526                            rows: crate::VectorSize::Bi,
3527                            width: 4,
3528                            ..
3529                        }) = get_inner_matrix_of_struct_array_member(
3530                            module, pointer, func_ctx, false,
3531                        )
3532                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3533                        {
3534                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3535                            let ptr_tr = resolved.pointer_base_type();
3536                            if let Some(ptr_ty) =
3537                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3538                            {
3539                                resolved = ptr_ty;
3540                            }
3541
3542                            write!(self.out, "((")?;
3543                            if let TypeInner::Array { base, size, .. } = *resolved {
3544                                self.write_type(module, base)?;
3545                                self.write_array_size(module, base, size)?;
3546                            } else {
3547                                self.write_value_type(module, resolved)?;
3548                            }
3549                            write!(self.out, ")")?;
3550                            close_paren = true;
3551                        }
3552
3553                        self.write_expr(module, pointer, func_ctx)?;
3554
3555                        if close_paren {
3556                            write!(self.out, ")")?;
3557                        }
3558                    }
3559                }
3560            }
3561            Expression::Unary { op, expr } => {
3562                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3563                let op_str = match op {
3564                    crate::UnaryOperator::Negate => {
3565                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3566                            Some(Scalar::I32) => NEG_FUNCTION,
3567                            _ => "-",
3568                        }
3569                    }
3570                    crate::UnaryOperator::LogicalNot => "!",
3571                    crate::UnaryOperator::BitwiseNot => "~",
3572                };
3573                write!(self.out, "{op_str}(")?;
3574                self.write_expr(module, expr, func_ctx)?;
3575                write!(self.out, ")")?;
3576            }
3577            Expression::As {
3578                expr,
3579                kind,
3580                convert,
3581            } => {
3582                let inner = func_ctx.resolve_type(expr, &module.types);
3583                if inner.scalar_kind() == Some(ScalarKind::Float)
3584                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3585                    && convert.is_some()
3586                {
3587                    // Use helper functions for float to int casts in order to
3588                    // avoid undefined behaviour when value is out of range for
3589                    // the target type.
3590                    let fun_name = match (kind, convert) {
3591                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3592                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3593                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3594                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3595                        _ => unreachable!(),
3596                    };
3597                    write!(self.out, "{fun_name}(")?;
3598                    self.write_expr(module, expr, func_ctx)?;
3599                    write!(self.out, ")")?;
3600                } else {
3601                    let close_paren = match convert {
3602                        Some(dst_width) => {
3603                            let scalar = Scalar {
3604                                kind,
3605                                width: dst_width,
3606                            };
3607                            match *inner {
3608                                TypeInner::Vector { size, .. } => {
3609                                    write!(
3610                                        self.out,
3611                                        "{}{}(",
3612                                        scalar.to_hlsl_str()?,
3613                                        common::vector_size_str(size)
3614                                    )?;
3615                                }
3616                                TypeInner::Scalar(_) => {
3617                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3618                                }
3619                                TypeInner::Matrix { columns, rows, .. } => {
3620                                    write!(
3621                                        self.out,
3622                                        "{}{}x{}(",
3623                                        scalar.to_hlsl_str()?,
3624                                        common::vector_size_str(columns),
3625                                        common::vector_size_str(rows)
3626                                    )?;
3627                                }
3628                                _ => {
3629                                    return Err(Error::Unimplemented(format!(
3630                                        "write_expr expression::as {inner:?}"
3631                                    )));
3632                                }
3633                            };
3634                            true
3635                        }
3636                        None => {
3637                            if inner.scalar_width() == Some(8) {
3638                                false
3639                            } else {
3640                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3641                                true
3642                            }
3643                        }
3644                    };
3645                    self.write_expr(module, expr, func_ctx)?;
3646                    if close_paren {
3647                        write!(self.out, ")")?;
3648                    }
3649                }
3650            }
3651            Expression::Math {
3652                fun,
3653                arg,
3654                arg1,
3655                arg2,
3656                arg3,
3657            } => {
3658                use crate::MathFunction as Mf;
3659
3660                enum Function {
3661                    Asincosh { is_sin: bool },
3662                    Atanh,
3663                    Pack2x16float,
3664                    Pack2x16snorm,
3665                    Pack2x16unorm,
3666                    Pack4x8snorm,
3667                    Pack4x8unorm,
3668                    Pack4xI8,
3669                    Pack4xU8,
3670                    Pack4xI8Clamp,
3671                    Pack4xU8Clamp,
3672                    Unpack2x16float,
3673                    Unpack2x16snorm,
3674                    Unpack2x16unorm,
3675                    Unpack4x8snorm,
3676                    Unpack4x8unorm,
3677                    Unpack4xI8,
3678                    Unpack4xU8,
3679                    Dot4I8Packed,
3680                    Dot4U8Packed,
3681                    QuantizeToF16,
3682                    Regular(&'static str),
3683                    MissingIntOverload(&'static str),
3684                    MissingIntReturnType(&'static str),
3685                    CountTrailingZeros,
3686                    CountLeadingZeros,
3687                }
3688
3689                let fun = match fun {
3690                    // comparison
3691                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3692                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3693                        _ => Function::Regular("abs"),
3694                    },
3695                    Mf::Min => Function::Regular("min"),
3696                    Mf::Max => Function::Regular("max"),
3697                    Mf::Clamp => Function::Regular("clamp"),
3698                    Mf::Saturate => Function::Regular("saturate"),
3699                    // trigonometry
3700                    Mf::Cos => Function::Regular("cos"),
3701                    Mf::Cosh => Function::Regular("cosh"),
3702                    Mf::Sin => Function::Regular("sin"),
3703                    Mf::Sinh => Function::Regular("sinh"),
3704                    Mf::Tan => Function::Regular("tan"),
3705                    Mf::Tanh => Function::Regular("tanh"),
3706                    Mf::Acos => Function::Regular("acos"),
3707                    Mf::Asin => Function::Regular("asin"),
3708                    Mf::Atan => Function::Regular("atan"),
3709                    Mf::Atan2 => Function::Regular("atan2"),
3710                    Mf::Asinh => Function::Asincosh { is_sin: true },
3711                    Mf::Acosh => Function::Asincosh { is_sin: false },
3712                    Mf::Atanh => Function::Atanh,
3713                    Mf::Radians => Function::Regular("radians"),
3714                    Mf::Degrees => Function::Regular("degrees"),
3715                    // decomposition
3716                    Mf::Ceil => Function::Regular("ceil"),
3717                    Mf::Floor => Function::Regular("floor"),
3718                    Mf::Round => Function::Regular("round"),
3719                    Mf::Fract => Function::Regular("frac"),
3720                    Mf::Trunc => Function::Regular("trunc"),
3721                    Mf::Modf => Function::Regular(MODF_FUNCTION),
3722                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3723                    Mf::Ldexp => Function::Regular("ldexp"),
3724                    // exponent
3725                    Mf::Exp => Function::Regular("exp"),
3726                    Mf::Exp2 => Function::Regular("exp2"),
3727                    Mf::Log => Function::Regular("log"),
3728                    Mf::Log2 => Function::Regular("log2"),
3729                    Mf::Pow => Function::Regular("pow"),
3730                    // geometry
3731                    Mf::Dot => Function::Regular("dot"),
3732                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
3733                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
3734                    //Mf::Outer => ,
3735                    Mf::Cross => Function::Regular("cross"),
3736                    Mf::Distance => Function::Regular("distance"),
3737                    Mf::Length => Function::Regular("length"),
3738                    Mf::Normalize => Function::Regular("normalize"),
3739                    Mf::FaceForward => Function::Regular("faceforward"),
3740                    Mf::Reflect => Function::Regular("reflect"),
3741                    Mf::Refract => Function::Regular("refract"),
3742                    // computational
3743                    Mf::Sign => Function::Regular("sign"),
3744                    Mf::Fma => Function::Regular("mad"),
3745                    Mf::Mix => Function::Regular("lerp"),
3746                    Mf::Step => Function::Regular("step"),
3747                    Mf::SmoothStep => Function::Regular("smoothstep"),
3748                    Mf::Sqrt => Function::Regular("sqrt"),
3749                    Mf::InverseSqrt => Function::Regular("rsqrt"),
3750                    //Mf::Inverse =>,
3751                    Mf::Transpose => Function::Regular("transpose"),
3752                    Mf::Determinant => Function::Regular("determinant"),
3753                    Mf::QuantizeToF16 => Function::QuantizeToF16,
3754                    // bits
3755                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
3756                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
3757                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3758                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3759                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3760                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3761                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3762                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3763                    // Data Packing
3764                    Mf::Pack2x16float => Function::Pack2x16float,
3765                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
3766                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
3767                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
3768                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
3769                    Mf::Pack4xI8 => Function::Pack4xI8,
3770                    Mf::Pack4xU8 => Function::Pack4xU8,
3771                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3772                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3773                    // Data Unpacking
3774                    Mf::Unpack2x16float => Function::Unpack2x16float,
3775                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3776                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3777                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3778                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3779                    Mf::Unpack4xI8 => Function::Unpack4xI8,
3780                    Mf::Unpack4xU8 => Function::Unpack4xU8,
3781                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3782                };
3783
3784                match fun {
3785                    Function::Asincosh { is_sin } => {
3786                        write!(self.out, "log(")?;
3787                        self.write_expr(module, arg, func_ctx)?;
3788                        write!(self.out, " + sqrt(")?;
3789                        self.write_expr(module, arg, func_ctx)?;
3790                        write!(self.out, " * ")?;
3791                        self.write_expr(module, arg, func_ctx)?;
3792                        match is_sin {
3793                            true => write!(self.out, " + 1.0))")?,
3794                            false => write!(self.out, " - 1.0))")?,
3795                        }
3796                    }
3797                    Function::Atanh => {
3798                        write!(self.out, "0.5 * log((1.0 + ")?;
3799                        self.write_expr(module, arg, func_ctx)?;
3800                        write!(self.out, ") / (1.0 - ")?;
3801                        self.write_expr(module, arg, func_ctx)?;
3802                        write!(self.out, "))")?;
3803                    }
3804                    Function::Pack2x16float => {
3805                        write!(self.out, "(f32tof16(")?;
3806                        self.write_expr(module, arg, func_ctx)?;
3807                        write!(self.out, "[0]) | f32tof16(")?;
3808                        self.write_expr(module, arg, func_ctx)?;
3809                        write!(self.out, "[1]) << 16)")?;
3810                    }
3811                    Function::Pack2x16snorm => {
3812                        let scale = 32767;
3813
3814                        write!(self.out, "uint((int(round(clamp(")?;
3815                        self.write_expr(module, arg, func_ctx)?;
3816                        write!(
3817                            self.out,
3818                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3819                        )?;
3820                        self.write_expr(module, arg, func_ctx)?;
3821                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3822                    }
3823                    Function::Pack2x16unorm => {
3824                        let scale = 65535;
3825
3826                        write!(self.out, "(uint(round(clamp(")?;
3827                        self.write_expr(module, arg, func_ctx)?;
3828                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3829                        self.write_expr(module, arg, func_ctx)?;
3830                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3831                    }
3832                    Function::Pack4x8snorm => {
3833                        let scale = 127;
3834
3835                        write!(self.out, "uint((int(round(clamp(")?;
3836                        self.write_expr(module, arg, func_ctx)?;
3837                        write!(
3838                            self.out,
3839                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3840                        )?;
3841                        self.write_expr(module, arg, func_ctx)?;
3842                        write!(
3843                            self.out,
3844                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3845                        )?;
3846                        self.write_expr(module, arg, func_ctx)?;
3847                        write!(
3848                            self.out,
3849                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3850                        )?;
3851                        self.write_expr(module, arg, func_ctx)?;
3852                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3853                    }
3854                    Function::Pack4x8unorm => {
3855                        let scale = 255;
3856
3857                        write!(self.out, "(uint(round(clamp(")?;
3858                        self.write_expr(module, arg, func_ctx)?;
3859                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3860                        self.write_expr(module, arg, func_ctx)?;
3861                        write!(
3862                            self.out,
3863                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3864                        )?;
3865                        self.write_expr(module, arg, func_ctx)?;
3866                        write!(
3867                            self.out,
3868                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3869                        )?;
3870                        self.write_expr(module, arg, func_ctx)?;
3871                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3872                    }
3873                    fun @ (Function::Pack4xI8
3874                    | Function::Pack4xU8
3875                    | Function::Pack4xI8Clamp
3876                    | Function::Pack4xU8Clamp) => {
3877                        let was_signed =
3878                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3879                        let clamp_bounds = match fun {
3880                            Function::Pack4xI8Clamp => Some(("-128", "127")),
3881                            Function::Pack4xU8Clamp => Some(("0", "255")),
3882                            _ => None,
3883                        };
3884                        if was_signed {
3885                            write!(self.out, "uint(")?;
3886                        }
3887                        let write_arg = |this: &mut Self| -> BackendResult {
3888                            if let Some((min, max)) = clamp_bounds {
3889                                write!(this.out, "clamp(")?;
3890                                this.write_expr(module, arg, func_ctx)?;
3891                                write!(this.out, ", {min}, {max})")?;
3892                            } else {
3893                                this.write_expr(module, arg, func_ctx)?;
3894                            }
3895                            Ok(())
3896                        };
3897                        write!(self.out, "(")?;
3898                        write_arg(self)?;
3899                        write!(self.out, "[0] & 0xFF) | ((")?;
3900                        write_arg(self)?;
3901                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3902                        write_arg(self)?;
3903                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3904                        write_arg(self)?;
3905                        write!(self.out, "[3] & 0xFF) << 24)")?;
3906                        if was_signed {
3907                            write!(self.out, ")")?;
3908                        }
3909                    }
3910
3911                    Function::Unpack2x16float => {
3912                        write!(self.out, "float2(f16tof32(")?;
3913                        self.write_expr(module, arg, func_ctx)?;
3914                        write!(self.out, "), f16tof32((")?;
3915                        self.write_expr(module, arg, func_ctx)?;
3916                        write!(self.out, ") >> 16))")?;
3917                    }
3918                    Function::Unpack2x16snorm => {
3919                        let scale = 32767;
3920
3921                        write!(self.out, "(float2(int2(")?;
3922                        self.write_expr(module, arg, func_ctx)?;
3923                        write!(self.out, " << 16, ")?;
3924                        self.write_expr(module, arg, func_ctx)?;
3925                        write!(self.out, ") >> 16) / {scale}.0)")?;
3926                    }
3927                    Function::Unpack2x16unorm => {
3928                        let scale = 65535;
3929
3930                        write!(self.out, "(float2(")?;
3931                        self.write_expr(module, arg, func_ctx)?;
3932                        write!(self.out, " & 0xFFFF, ")?;
3933                        self.write_expr(module, arg, func_ctx)?;
3934                        write!(self.out, " >> 16) / {scale}.0)")?;
3935                    }
3936                    Function::Unpack4x8snorm => {
3937                        let scale = 127;
3938
3939                        write!(self.out, "(float4(int4(")?;
3940                        self.write_expr(module, arg, func_ctx)?;
3941                        write!(self.out, " << 24, ")?;
3942                        self.write_expr(module, arg, func_ctx)?;
3943                        write!(self.out, " << 16, ")?;
3944                        self.write_expr(module, arg, func_ctx)?;
3945                        write!(self.out, " << 8, ")?;
3946                        self.write_expr(module, arg, func_ctx)?;
3947                        write!(self.out, ") >> 24) / {scale}.0)")?;
3948                    }
3949                    Function::Unpack4x8unorm => {
3950                        let scale = 255;
3951
3952                        write!(self.out, "(float4(")?;
3953                        self.write_expr(module, arg, func_ctx)?;
3954                        write!(self.out, " & 0xFF, ")?;
3955                        self.write_expr(module, arg, func_ctx)?;
3956                        write!(self.out, " >> 8 & 0xFF, ")?;
3957                        self.write_expr(module, arg, func_ctx)?;
3958                        write!(self.out, " >> 16 & 0xFF, ")?;
3959                        self.write_expr(module, arg, func_ctx)?;
3960                        write!(self.out, " >> 24) / {scale}.0)")?;
3961                    }
3962                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3963                        write!(self.out, "(")?;
3964                        if matches!(fun, Function::Unpack4xU8) {
3965                            write!(self.out, "u")?;
3966                        }
3967                        write!(self.out, "int4(")?;
3968                        self.write_expr(module, arg, func_ctx)?;
3969                        write!(self.out, ", ")?;
3970                        self.write_expr(module, arg, func_ctx)?;
3971                        write!(self.out, " >> 8, ")?;
3972                        self.write_expr(module, arg, func_ctx)?;
3973                        write!(self.out, " >> 16, ")?;
3974                        self.write_expr(module, arg, func_ctx)?;
3975                        write!(self.out, " >> 24) << 24 >> 24)")?;
3976                    }
3977                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3978                        let arg1 = arg1.unwrap();
3979
3980                        if self.options.shader_model >= ShaderModel::V6_4 {
3981                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
3982                            let function_name = match fun {
3983                                Function::Dot4I8Packed => "dot4add_i8packed",
3984                                Function::Dot4U8Packed => "dot4add_u8packed",
3985                                _ => unreachable!(),
3986                            };
3987                            write!(self.out, "{function_name}(")?;
3988                            self.write_expr(module, arg, func_ctx)?;
3989                            write!(self.out, ", ")?;
3990                            self.write_expr(module, arg1, func_ctx)?;
3991                            write!(self.out, ", 0)")?;
3992                        } else {
3993                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
3994                            write!(self.out, "dot(")?;
3995
3996                            if matches!(fun, Function::Dot4U8Packed) {
3997                                write!(self.out, "u")?;
3998                            }
3999                            write!(self.out, "int4(")?;
4000                            self.write_expr(module, arg, func_ctx)?;
4001                            write!(self.out, ", ")?;
4002                            self.write_expr(module, arg, func_ctx)?;
4003                            write!(self.out, " >> 8, ")?;
4004                            self.write_expr(module, arg, func_ctx)?;
4005                            write!(self.out, " >> 16, ")?;
4006                            self.write_expr(module, arg, func_ctx)?;
4007                            write!(self.out, " >> 24) << 24 >> 24, ")?;
4008
4009                            if matches!(fun, Function::Dot4U8Packed) {
4010                                write!(self.out, "u")?;
4011                            }
4012                            write!(self.out, "int4(")?;
4013                            self.write_expr(module, arg1, func_ctx)?;
4014                            write!(self.out, ", ")?;
4015                            self.write_expr(module, arg1, func_ctx)?;
4016                            write!(self.out, " >> 8, ")?;
4017                            self.write_expr(module, arg1, func_ctx)?;
4018                            write!(self.out, " >> 16, ")?;
4019                            self.write_expr(module, arg1, func_ctx)?;
4020                            write!(self.out, " >> 24) << 24 >> 24)")?;
4021                        }
4022                    }
4023                    Function::QuantizeToF16 => {
4024                        write!(self.out, "f16tof32(f32tof16(")?;
4025                        self.write_expr(module, arg, func_ctx)?;
4026                        write!(self.out, "))")?;
4027                    }
4028                    Function::Regular(fun_name) => {
4029                        write!(self.out, "{fun_name}(")?;
4030                        self.write_expr(module, arg, func_ctx)?;
4031                        if let Some(arg) = arg1 {
4032                            write!(self.out, ", ")?;
4033                            self.write_expr(module, arg, func_ctx)?;
4034                        }
4035                        if let Some(arg) = arg2 {
4036                            write!(self.out, ", ")?;
4037                            self.write_expr(module, arg, func_ctx)?;
4038                        }
4039                        if let Some(arg) = arg3 {
4040                            write!(self.out, ", ")?;
4041                            self.write_expr(module, arg, func_ctx)?;
4042                        }
4043                        write!(self.out, ")")?
4044                    }
4045                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4046                    // as non-32bit types are DXC only.
4047                    Function::MissingIntOverload(fun_name) => {
4048                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4049                        if let Some(Scalar::I32) = scalar_kind {
4050                            write!(self.out, "asint({fun_name}(asuint(")?;
4051                            self.write_expr(module, arg, func_ctx)?;
4052                            write!(self.out, ")))")?;
4053                        } else {
4054                            write!(self.out, "{fun_name}(")?;
4055                            self.write_expr(module, arg, func_ctx)?;
4056                            write!(self.out, ")")?;
4057                        }
4058                    }
4059                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4060                    // as non-32bit types are DXC only.
4061                    Function::MissingIntReturnType(fun_name) => {
4062                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4063                        if let Some(Scalar::I32) = scalar_kind {
4064                            write!(self.out, "asint({fun_name}(")?;
4065                            self.write_expr(module, arg, func_ctx)?;
4066                            write!(self.out, "))")?;
4067                        } else {
4068                            write!(self.out, "{fun_name}(")?;
4069                            self.write_expr(module, arg, func_ctx)?;
4070                            write!(self.out, ")")?;
4071                        }
4072                    }
4073                    Function::CountTrailingZeros => {
4074                        match *func_ctx.resolve_type(arg, &module.types) {
4075                            TypeInner::Vector { size, scalar } => {
4076                                let s = match size {
4077                                    crate::VectorSize::Bi => ".xx",
4078                                    crate::VectorSize::Tri => ".xxx",
4079                                    crate::VectorSize::Quad => ".xxxx",
4080                                };
4081
4082                                let scalar_width_bits = scalar.width * 8;
4083
4084                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4085                                    write!(
4086                                        self.out,
4087                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4088                                    )?;
4089                                    self.write_expr(module, arg, func_ctx)?;
4090                                    write!(self.out, "))")?;
4091                                } else {
4092                                    // This is only needed for the FXC path, on 32bit signed integers.
4093                                    write!(
4094                                        self.out,
4095                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4096                                    )?;
4097                                    self.write_expr(module, arg, func_ctx)?;
4098                                    write!(self.out, ")))")?;
4099                                }
4100                            }
4101                            TypeInner::Scalar(scalar) => {
4102                                let scalar_width_bits = scalar.width * 8;
4103
4104                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4105                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4106                                    self.write_expr(module, arg, func_ctx)?;
4107                                    write!(self.out, "))")?;
4108                                } else {
4109                                    // This is only needed for the FXC path, on 32bit signed integers.
4110                                    write!(
4111                                        self.out,
4112                                        "asint(min({scalar_width_bits}u, firstbitlow("
4113                                    )?;
4114                                    self.write_expr(module, arg, func_ctx)?;
4115                                    write!(self.out, ")))")?;
4116                                }
4117                            }
4118                            _ => unreachable!(),
4119                        }
4120
4121                        return Ok(());
4122                    }
4123                    Function::CountLeadingZeros => {
4124                        match *func_ctx.resolve_type(arg, &module.types) {
4125                            TypeInner::Vector { size, scalar } => {
4126                                let s = match size {
4127                                    crate::VectorSize::Bi => ".xx",
4128                                    crate::VectorSize::Tri => ".xxx",
4129                                    crate::VectorSize::Quad => ".xxxx",
4130                                };
4131
4132                                // scalar width - 1
4133                                let constant = scalar.width * 8 - 1;
4134
4135                                if scalar.kind == ScalarKind::Uint {
4136                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4137                                    self.write_expr(module, arg, func_ctx)?;
4138                                    write!(self.out, "))")?;
4139                                } else {
4140                                    let conversion_func = match scalar.width {
4141                                        4 => "asint",
4142                                        _ => "",
4143                                    };
4144                                    write!(self.out, "(")?;
4145                                    self.write_expr(module, arg, func_ctx)?;
4146                                    write!(
4147                                        self.out,
4148                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4149                                    )?;
4150                                    self.write_expr(module, arg, func_ctx)?;
4151                                    write!(self.out, ")))")?;
4152                                }
4153                            }
4154                            TypeInner::Scalar(scalar) => {
4155                                // scalar width - 1
4156                                let constant = scalar.width * 8 - 1;
4157
4158                                if let ScalarKind::Uint = scalar.kind {
4159                                    write!(self.out, "({constant}u - firstbithigh(")?;
4160                                    self.write_expr(module, arg, func_ctx)?;
4161                                    write!(self.out, "))")?;
4162                                } else {
4163                                    let conversion_func = match scalar.width {
4164                                        4 => "asint",
4165                                        _ => "",
4166                                    };
4167                                    write!(self.out, "(")?;
4168                                    self.write_expr(module, arg, func_ctx)?;
4169                                    write!(
4170                                        self.out,
4171                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4172                                    )?;
4173                                    self.write_expr(module, arg, func_ctx)?;
4174                                    write!(self.out, ")))")?;
4175                                }
4176                            }
4177                            _ => unreachable!(),
4178                        }
4179
4180                        return Ok(());
4181                    }
4182                }
4183            }
4184            Expression::Swizzle {
4185                size,
4186                vector,
4187                pattern,
4188            } => {
4189                self.write_expr(module, vector, func_ctx)?;
4190                write!(self.out, ".")?;
4191                for &sc in pattern[..size as usize].iter() {
4192                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4193                }
4194            }
4195            Expression::ArrayLength(expr) => {
4196                let var_handle = match func_ctx.expressions[expr] {
4197                    Expression::AccessIndex { base, index: _ } => {
4198                        match func_ctx.expressions[base] {
4199                            Expression::GlobalVariable(handle) => handle,
4200                            _ => unreachable!(),
4201                        }
4202                    }
4203                    Expression::GlobalVariable(handle) => handle,
4204                    _ => unreachable!(),
4205                };
4206
4207                let var = &module.global_variables[var_handle];
4208                let (offset, stride) = match module.types[var.ty].inner {
4209                    TypeInner::Array { stride, .. } => (0, stride),
4210                    TypeInner::Struct { ref members, .. } => {
4211                        let last = members.last().unwrap();
4212                        let stride = match module.types[last.ty].inner {
4213                            TypeInner::Array { stride, .. } => stride,
4214                            _ => unreachable!(),
4215                        };
4216                        (last.offset, stride)
4217                    }
4218                    _ => unreachable!(),
4219                };
4220
4221                let storage_access = match var.space {
4222                    crate::AddressSpace::Storage { access } => access,
4223                    _ => crate::StorageAccess::default(),
4224                };
4225                let wrapped_array_length = WrappedArrayLength {
4226                    writable: storage_access.contains(crate::StorageAccess::STORE),
4227                };
4228
4229                write!(self.out, "((")?;
4230                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4231                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4232                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4233            }
4234            Expression::Derivative { axis, ctrl, expr } => {
4235                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4236                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4237                    let tail = match ctrl {
4238                        Ctrl::Coarse => "coarse",
4239                        Ctrl::Fine => "fine",
4240                        Ctrl::None => unreachable!(),
4241                    };
4242                    write!(self.out, "abs(ddx_{tail}(")?;
4243                    self.write_expr(module, expr, func_ctx)?;
4244                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4245                    self.write_expr(module, expr, func_ctx)?;
4246                    write!(self.out, "))")?
4247                } else {
4248                    let fun_str = match (axis, ctrl) {
4249                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4250                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4251                        (Axis::X, Ctrl::None) => "ddx",
4252                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4253                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4254                        (Axis::Y, Ctrl::None) => "ddy",
4255                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4256                        (Axis::Width, Ctrl::None) => "fwidth",
4257                    };
4258                    write!(self.out, "{fun_str}(")?;
4259                    self.write_expr(module, expr, func_ctx)?;
4260                    write!(self.out, ")")?
4261                }
4262            }
4263            Expression::Relational { fun, argument } => {
4264                use crate::RelationalFunction as Rf;
4265
4266                let fun_str = match fun {
4267                    Rf::All => "all",
4268                    Rf::Any => "any",
4269                    Rf::IsNan => "isnan",
4270                    Rf::IsInf => "isinf",
4271                };
4272                write!(self.out, "{fun_str}(")?;
4273                self.write_expr(module, argument, func_ctx)?;
4274                write!(self.out, ")")?
4275            }
4276            Expression::Select {
4277                condition,
4278                accept,
4279                reject,
4280            } => {
4281                write!(self.out, "(")?;
4282                self.write_expr(module, condition, func_ctx)?;
4283                write!(self.out, " ? ")?;
4284                self.write_expr(module, accept, func_ctx)?;
4285                write!(self.out, " : ")?;
4286                self.write_expr(module, reject, func_ctx)?;
4287                write!(self.out, ")")?
4288            }
4289            Expression::RayQueryGetIntersection { query, committed } => {
4290                if committed {
4291                    write!(self.out, "GetCommittedIntersection(")?;
4292                    self.write_expr(module, query, func_ctx)?;
4293                    write!(self.out, ")")?;
4294                } else {
4295                    write!(self.out, "GetCandidateIntersection(")?;
4296                    self.write_expr(module, query, func_ctx)?;
4297                    write!(self.out, ")")?;
4298                }
4299            }
4300            // Not supported yet
4301            Expression::RayQueryVertexPositions { .. } => unreachable!(),
4302            // Nothing to do here, since call expression already cached
4303            Expression::CallResult(_)
4304            | Expression::AtomicResult { .. }
4305            | Expression::WorkGroupUniformLoadResult { .. }
4306            | Expression::RayQueryProceedResult
4307            | Expression::SubgroupBallotResult
4308            | Expression::SubgroupOperationResult { .. } => {}
4309        }
4310
4311        if !closing_bracket.is_empty() {
4312            write!(self.out, "{closing_bracket}")?;
4313        }
4314        Ok(())
4315    }
4316
4317    #[allow(clippy::too_many_arguments)]
4318    fn write_image_load(
4319        &mut self,
4320        module: &&Module,
4321        expr: Handle<crate::Expression>,
4322        func_ctx: &back::FunctionCtx,
4323        image: Handle<crate::Expression>,
4324        coordinate: Handle<crate::Expression>,
4325        array_index: Option<Handle<crate::Expression>>,
4326        sample: Option<Handle<crate::Expression>>,
4327        level: Option<Handle<crate::Expression>>,
4328    ) -> Result<(), Error> {
4329        let mut wrapping_type = None;
4330        match *func_ctx.resolve_type(image, &module.types) {
4331            TypeInner::Image {
4332                class: crate::ImageClass::External,
4333                ..
4334            } => {
4335                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4336                self.write_expr(module, image, func_ctx)?;
4337                write!(self.out, ", ")?;
4338                self.write_expr(module, coordinate, func_ctx)?;
4339                write!(self.out, ")")?;
4340                return Ok(());
4341            }
4342            TypeInner::Image {
4343                class: crate::ImageClass::Storage { format, .. },
4344                ..
4345            } => {
4346                if format.single_component() {
4347                    wrapping_type = Some(Scalar::from(format));
4348                }
4349            }
4350            _ => {}
4351        }
4352        if let Some(scalar) = wrapping_type {
4353            write!(
4354                self.out,
4355                "{}{}(",
4356                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4357                scalar.to_hlsl_str()?
4358            )?;
4359        }
4360        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4361        self.write_expr(module, image, func_ctx)?;
4362        write!(self.out, ".Load(")?;
4363
4364        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4365
4366        if let Some(sample) = sample {
4367            write!(self.out, ", ")?;
4368            self.write_expr(module, sample, func_ctx)?;
4369        }
4370
4371        // close bracket for Load function
4372        write!(self.out, ")")?;
4373
4374        if wrapping_type.is_some() {
4375            write!(self.out, ")")?;
4376        }
4377
4378        // return x component if return type is scalar
4379        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4380            write!(self.out, ".x")?;
4381        }
4382        Ok(())
4383    }
4384
4385    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4386    /// can be generated later.
4387    fn sampler_binding_array_info_from_expression(
4388        &mut self,
4389        module: &Module,
4390        func_ctx: &back::FunctionCtx<'_>,
4391        base: Handle<crate::Expression>,
4392        resolved: &TypeInner,
4393    ) -> Option<BindingArraySamplerInfo> {
4394        if let TypeInner::BindingArray {
4395            base: base_ty_handle,
4396            ..
4397        } = *resolved
4398        {
4399            let base_ty = &module.types[base_ty_handle].inner;
4400            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4401                let base = &func_ctx.expressions[base];
4402
4403                if let crate::Expression::GlobalVariable(handle) = *base {
4404                    let variable = &module.global_variables[handle];
4405
4406                    let sampler_heap_name = match comparison {
4407                        true => COMPARISON_SAMPLER_HEAP_VAR,
4408                        false => SAMPLER_HEAP_VAR,
4409                    };
4410
4411                    return Some(BindingArraySamplerInfo {
4412                        sampler_heap_name,
4413                        sampler_index_buffer_name: self
4414                            .wrapped
4415                            .sampler_index_buffers
4416                            .get(&super::SamplerIndexBufferKey {
4417                                group: variable.binding.unwrap().group,
4418                            })
4419                            .unwrap()
4420                            .clone(),
4421                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4422                            .clone(),
4423                    });
4424                }
4425            }
4426        }
4427
4428        None
4429    }
4430
4431    fn write_named_expr(
4432        &mut self,
4433        module: &Module,
4434        handle: Handle<crate::Expression>,
4435        name: String,
4436        // The expression which is being named.
4437        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4438        named: Handle<crate::Expression>,
4439        ctx: &back::FunctionCtx,
4440    ) -> BackendResult {
4441        match ctx.info[named].ty {
4442            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4443                TypeInner::Struct { .. } => {
4444                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4445                    write!(self.out, "{ty_name}")?;
4446                }
4447                _ => {
4448                    self.write_type(module, ty_handle)?;
4449                }
4450            },
4451            proc::TypeResolution::Value(ref inner) => {
4452                self.write_value_type(module, inner)?;
4453            }
4454        }
4455
4456        let resolved = ctx.resolve_type(named, &module.types);
4457
4458        write!(self.out, " {name}")?;
4459        // If rhs is a array type, we should write array size
4460        if let TypeInner::Array { base, size, .. } = *resolved {
4461            self.write_array_size(module, base, size)?;
4462        }
4463        write!(self.out, " = ")?;
4464        self.write_expr(module, handle, ctx)?;
4465        writeln!(self.out, ";")?;
4466        self.named_expressions.insert(named, name);
4467
4468        Ok(())
4469    }
4470
4471    /// Helper function that write default zero initialization
4472    pub(super) fn write_default_init(
4473        &mut self,
4474        module: &Module,
4475        ty: Handle<crate::Type>,
4476    ) -> BackendResult {
4477        write!(self.out, "(")?;
4478        self.write_type(module, ty)?;
4479        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4480            self.write_array_size(module, base, size)?;
4481        }
4482        write!(self.out, ")0")?;
4483        Ok(())
4484    }
4485
4486    fn write_control_barrier(
4487        &mut self,
4488        barrier: crate::Barrier,
4489        level: back::Level,
4490    ) -> BackendResult {
4491        if barrier.contains(crate::Barrier::STORAGE) {
4492            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4493        }
4494        if barrier.contains(crate::Barrier::WORK_GROUP) {
4495            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4496        }
4497        if barrier.contains(crate::Barrier::SUB_GROUP) {
4498            // Does not exist in DirectX
4499        }
4500        if barrier.contains(crate::Barrier::TEXTURE) {
4501            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4502        }
4503        Ok(())
4504    }
4505
4506    fn write_memory_barrier(
4507        &mut self,
4508        barrier: crate::Barrier,
4509        level: back::Level,
4510    ) -> BackendResult {
4511        if barrier.contains(crate::Barrier::STORAGE) {
4512            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4513        }
4514        if barrier.contains(crate::Barrier::WORK_GROUP) {
4515            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4516        }
4517        if barrier.contains(crate::Barrier::SUB_GROUP) {
4518            // Does not exist in DirectX
4519        }
4520        if barrier.contains(crate::Barrier::TEXTURE) {
4521            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4522        }
4523        Ok(())
4524    }
4525
4526    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4527    fn emit_hlsl_atomic_tail(
4528        &mut self,
4529        module: &Module,
4530        func_ctx: &back::FunctionCtx<'_>,
4531        fun: &crate::AtomicFunction,
4532        compare_expr: Option<Handle<crate::Expression>>,
4533        value: Handle<crate::Expression>,
4534        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4535    ) -> BackendResult {
4536        if let Some(cmp) = compare_expr {
4537            write!(self.out, ", ")?;
4538            self.write_expr(module, cmp, func_ctx)?;
4539        }
4540        write!(self.out, ", ")?;
4541        if let crate::AtomicFunction::Subtract = *fun {
4542            // we just wrote `InterlockedAdd`, so negate the argument
4543            write!(self.out, "-")?;
4544        }
4545        self.write_expr(module, value, func_ctx)?;
4546        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4547            write!(self.out, ", ")?;
4548            if compare_expr.is_some() {
4549                write!(self.out, "{res_name}.old_value")?;
4550            } else {
4551                write!(self.out, "{res_name}")?;
4552            }
4553        }
4554        writeln!(self.out, ");")?;
4555        Ok(())
4556    }
4557}
4558
4559pub(super) struct MatrixType {
4560    pub(super) columns: crate::VectorSize,
4561    pub(super) rows: crate::VectorSize,
4562    pub(super) width: crate::Bytes,
4563}
4564
4565pub(super) fn get_inner_matrix_data(
4566    module: &Module,
4567    handle: Handle<crate::Type>,
4568) -> Option<MatrixType> {
4569    match module.types[handle].inner {
4570        TypeInner::Matrix {
4571            columns,
4572            rows,
4573            scalar,
4574        } => Some(MatrixType {
4575            columns,
4576            rows,
4577            width: scalar.width,
4578        }),
4579        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4580        _ => None,
4581    }
4582}
4583
4584/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4585/// returns a tuple of the matrix, the column (vector) index (if present), and
4586/// the row (scalar) index (if present).
4587fn find_matrix_in_access_chain(
4588    module: &Module,
4589    base: Handle<crate::Expression>,
4590    func_ctx: &back::FunctionCtx<'_>,
4591) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4592    let mut current_base = base;
4593    let mut vector = None;
4594    let mut scalar = None;
4595    loop {
4596        let resolved_tr = func_ctx
4597            .resolve_type(current_base, &module.types)
4598            .pointer_base_type();
4599        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4600
4601        match *resolved {
4602            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4603            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4604            _ => return None,
4605        }
4606
4607        let index;
4608        (current_base, index) = match func_ctx.expressions[current_base] {
4609            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4610            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4611            _ => return None,
4612        };
4613
4614        match *resolved {
4615            TypeInner::Scalar(_) => scalar = Some(index),
4616            TypeInner::Vector { .. } => vector = Some(index),
4617            _ => unreachable!(),
4618        }
4619    }
4620}
4621
4622/// Returns the matrix data if the access chain starting at `base`:
4623/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4624/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4625/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4626pub(super) fn get_inner_matrix_of_struct_array_member(
4627    module: &Module,
4628    base: Handle<crate::Expression>,
4629    func_ctx: &back::FunctionCtx<'_>,
4630    direct: bool,
4631) -> Option<MatrixType> {
4632    let mut mat_data = None;
4633    let mut array_base = None;
4634
4635    let mut current_base = base;
4636    loop {
4637        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4638        if let TypeInner::Pointer { base, .. } = *resolved {
4639            resolved = &module.types[base].inner;
4640        };
4641
4642        match *resolved {
4643            TypeInner::Matrix {
4644                columns,
4645                rows,
4646                scalar,
4647            } => {
4648                mat_data = Some(MatrixType {
4649                    columns,
4650                    rows,
4651                    width: scalar.width,
4652                })
4653            }
4654            TypeInner::Array { base, .. } => {
4655                array_base = Some(base);
4656            }
4657            TypeInner::Struct { .. } => {
4658                if let Some(array_base) = array_base {
4659                    if direct {
4660                        return mat_data;
4661                    } else {
4662                        return get_inner_matrix_data(module, array_base);
4663                    }
4664                }
4665
4666                break;
4667            }
4668            _ => break,
4669        }
4670
4671        current_base = match func_ctx.expressions[current_base] {
4672            crate::Expression::Access { base, .. } => base,
4673            crate::Expression::AccessIndex { base, .. } => base,
4674            _ => break,
4675        };
4676    }
4677    None
4678}
4679
4680/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
4681/// immediate expression, rather than traversing an access chain.
4682fn get_global_uniform_matrix(
4683    module: &Module,
4684    base: Handle<crate::Expression>,
4685    func_ctx: &back::FunctionCtx<'_>,
4686) -> Option<MatrixType> {
4687    let base_tr = func_ctx
4688        .resolve_type(base, &module.types)
4689        .pointer_base_type();
4690    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4691    match (&func_ctx.expressions[base], base_ty) {
4692        (
4693            &crate::Expression::GlobalVariable(handle),
4694            Some(&TypeInner::Matrix {
4695                columns,
4696                rows,
4697                scalar,
4698            }),
4699        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4700            Some(MatrixType {
4701                columns,
4702                rows,
4703                width: scalar.width,
4704            })
4705        }
4706        _ => None,
4707    }
4708}
4709
4710/// Returns the matrix data if the access chain starting at `base`:
4711/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
4712/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4713/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
4714fn get_inner_matrix_of_global_uniform(
4715    module: &Module,
4716    base: Handle<crate::Expression>,
4717    func_ctx: &back::FunctionCtx<'_>,
4718) -> Option<MatrixType> {
4719    let mut mat_data = None;
4720    let mut array_base = None;
4721
4722    let mut current_base = base;
4723    loop {
4724        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4725        if let TypeInner::Pointer { base, .. } = *resolved {
4726            resolved = &module.types[base].inner;
4727        };
4728
4729        match *resolved {
4730            TypeInner::Matrix {
4731                columns,
4732                rows,
4733                scalar,
4734            } => {
4735                mat_data = Some(MatrixType {
4736                    columns,
4737                    rows,
4738                    width: scalar.width,
4739                })
4740            }
4741            TypeInner::Array { base, .. } => {
4742                array_base = Some(base);
4743            }
4744            _ => break,
4745        }
4746
4747        current_base = match func_ctx.expressions[current_base] {
4748            crate::Expression::Access { base, .. } => base,
4749            crate::Expression::AccessIndex { base, .. } => base,
4750            crate::Expression::GlobalVariable(handle)
4751                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4752            {
4753                return mat_data.or_else(|| {
4754                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4755                })
4756            }
4757            _ => break,
4758        };
4759    }
4760    None
4761}