// naga/back/hlsl/writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9    help,
10    help::{
11        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12        WrappedZeroValue,
13    },
14    storage::StoreValue,
15    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18    back::{self, get_entry_points, Baked},
19    common,
20    proc::{self, index, ExternalTextureNameKey, NameKey},
21    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Semantic prefix for user-defined varyings: a `@location(N)` member is written
/// with the semantic `LOCN`, except fragment-stage outputs, which use `SV_TargetN`.
const LOCATION_SEMANTIC: &str = "LOC";
/// Type name of the constant buffer emitted when `special_constants_binding` is set.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the special constant buffer (of type [`SPECIAL_CBUF_TYPE`]).
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special constant buffer holding the first vertex.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special constant buffer holding the first instance.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Extra `uint` field of the special constant buffer.
const SPECIAL_OTHER: &str = "other";

// Identifiers of helper functions and variables generated by this backend.
// Their definitions and uses are outside this part of the file; the meaning of
// each is inferred from its name — confirm against the emitting code.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix shared by identifiers that naga generates internally
/// (e.g. `naga_modf`, `naga_div` above), as opposed to user-declared names.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
52
/// An index operand: either computed at runtime by an IR expression
/// (`Expression`) or a constant known while translating (`Static`).
enum Index {
    Expression(Handle<crate::Expression>),
    Static(u32),
}
57
/// One field of a flattened entry-point interface struct.
struct EpStructMember {
    /// HLSL field name, allocated through the namer (fallback `member`).
    name: String,
    /// IR type of the field.
    ty: Handle<crate::Type>,
    /// Location or built-in binding of the field; it decides the HLSL semantic.
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    binding: Option<crate::Binding>,
    /// Position of the field in the source order: the flattened argument list for
    /// inputs, or the result struct's member index for outputs. It is used to
    /// restore that order after sorting by binding.
    index: u32,
}
66
/// Information required to generate, and later refer to, the wrapper
/// structure that flattens all of an entry point's inputs or outputs.
struct EntryPointBinding {
    /// Name of the fake EP argument that contains the struct
    /// with all the flattened input data.
    arg_name: String,
    /// Generated structure name
    ty_name: String,
    /// Members of generated structure
    members: Vec<EpStructMember>,
}
78
/// Interface structs generated for one entry point by `write_ep_interface`.
pub(super) struct EntryPointInterface {
    /// If `Some`, the input of an entry point is gathered in a special
    /// struct with members sorted by binding.
    /// The `EntryPointBinding::members` array is sorted by index,
    /// so that we can walk it in `write_ep_arguments_initialization`.
    input: Option<EntryPointBinding>,
    /// If `Some`, the output of an entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in the `Statement::Return` handler.
    output: Option<EntryPointBinding>,
}
90
/// Sort key for members of an entry-point interface struct.
///
/// The derived `Ord` compares variants in declaration order, so `@location`
/// members come first (ascending by location), then built-ins, then unbound
/// members. Do not reorder the variants: the emitted struct layout depends on it.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    Location(u32),
    BuiltIn(crate::BuiltIn),
    Other,
}
97
98impl InterfaceKey {
99    const fn new(binding: Option<&crate::Binding>) -> Self {
100        match binding {
101            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
102            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
103            None => Self::Other,
104        }
105    }
106}
107
/// Direction of an entry-point interface: stage inputs or stage outputs.
///
/// Paired with a [`ShaderStage`], this decides, for example, whether a
/// `@location` member gets the `SV_Target{N}` semantic (fragment outputs)
/// or `LOC{N}`.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    Input,
    Output,
}
113
114const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
115    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
116        return false;
117    };
118    matches!(
119        builtin,
120        crate::BuiltIn::SubgroupSize
121            | crate::BuiltIn::SubgroupInvocationId
122            | crate::BuiltIn::NumSubgroups
123            | crate::BuiltIn::SubgroupId
124    )
125}
126
/// Information for how to generate a `binding_array<sampler>` access.
///
/// NOTE(review): `sampler_heap_name` is presumably one of `SAMPLER_HEAP_VAR` /
/// `COMPARISON_SAMPLER_HEAP_VAR` (hence `&'static str`) — confirm at the
/// construction site.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
136
137impl<'a, W: fmt::Write> super::Writer<'a, W> {
138    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
139        Self {
140            out,
141            names: crate::FastHashMap::default(),
142            namer: proc::Namer::default(),
143            options,
144            pipeline_options,
145            entry_point_io: crate::FastHashMap::default(),
146            named_expressions: crate::NamedExpressions::default(),
147            wrapped: super::Wrapped::default(),
148            written_committed_intersection: false,
149            written_candidate_intersection: false,
150            continue_ctx: back::continue_forward::ContinueCtx::default(),
151            temp_access_chain: Vec::new(),
152            need_bake_expressions: Default::default(),
153        }
154    }
155
156    fn reset(&mut self, module: &Module) {
157        self.names.clear();
158        self.namer.reset(
159            module,
160            &super::keywords::RESERVED_SET,
161            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
162            super::keywords::RESERVED_PREFIXES,
163            &mut self.names,
164        );
165        self.entry_point_io.clear();
166        self.named_expressions.clear();
167        self.wrapped.clear();
168        self.written_committed_intersection = false;
169        self.written_candidate_intersection = false;
170        self.continue_ctx.clear();
171        self.need_bake_expressions.clear();
172    }
173
174    /// Generates statements to be inserted immediately before and at the very
175    /// start of the body of each loop, to defeat infinite loop reasoning.
176    /// The 0th item of the returned tuple should be inserted immediately prior
177    /// to the loop and the 1st item should be inserted at the very start of
178    /// the loop body.
179    ///
180    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
181    fn gen_force_bounded_loop_statements(
182        &mut self,
183        level: back::Level,
184    ) -> Option<(String, String)> {
185        if !self.options.force_loop_bounding {
186            return None;
187        }
188
189        let loop_bound_name = self.namer.call("loop_bound");
190        let max = u32::MAX;
191        // Count down from u32::MAX rather than up from 0 to avoid hang on
192        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
193        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
194        let level = level.next();
195        let break_and_inc = format!(
196            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
197{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
198        );
199
200        Some((decl, break_and_inc))
201    }
202
203    /// Helper method used to find which expressions of a given function require baking
204    ///
205    /// # Notes
206    /// Clears `need_bake_expressions` set before adding to it
207    fn update_expressions_to_bake(
208        &mut self,
209        module: &Module,
210        func: &crate::Function,
211        info: &valid::FunctionInfo,
212    ) {
213        use crate::Expression;
214        self.need_bake_expressions.clear();
215        for (exp_handle, expr) in func.expressions.iter() {
216            let expr_info = &info[exp_handle];
217            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
218            if min_ref_count <= expr_info.ref_count {
219                self.need_bake_expressions.insert(exp_handle);
220            }
221
222            if let Expression::Math { fun, arg, arg1, .. } = *expr {
223                match fun {
224                    crate::MathFunction::Asinh
225                    | crate::MathFunction::Acosh
226                    | crate::MathFunction::Atanh
227                    | crate::MathFunction::Unpack2x16float
228                    | crate::MathFunction::Unpack2x16snorm
229                    | crate::MathFunction::Unpack2x16unorm
230                    | crate::MathFunction::Unpack4x8snorm
231                    | crate::MathFunction::Unpack4x8unorm
232                    | crate::MathFunction::Unpack4xI8
233                    | crate::MathFunction::Unpack4xU8
234                    | crate::MathFunction::Pack2x16float
235                    | crate::MathFunction::Pack2x16snorm
236                    | crate::MathFunction::Pack2x16unorm
237                    | crate::MathFunction::Pack4x8snorm
238                    | crate::MathFunction::Pack4x8unorm
239                    | crate::MathFunction::Pack4xI8
240                    | crate::MathFunction::Pack4xU8
241                    | crate::MathFunction::Pack4xI8Clamp
242                    | crate::MathFunction::Pack4xU8Clamp => {
243                        self.need_bake_expressions.insert(arg);
244                    }
245                    crate::MathFunction::CountLeadingZeros => {
246                        let inner = info[exp_handle].ty.inner_with(&module.types);
247                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
248                            self.need_bake_expressions.insert(arg);
249                        }
250                    }
251                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
252                        self.need_bake_expressions.insert(arg);
253                        self.need_bake_expressions.insert(arg1.unwrap());
254                    }
255                    _ => {}
256                }
257            }
258
259            if let Expression::Derivative { axis, ctrl, expr } = *expr {
260                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
261                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
262                    self.need_bake_expressions.insert(expr);
263                }
264            }
265
266            if let Expression::GlobalVariable(_) = *expr {
267                let inner = info[exp_handle].ty.inner_with(&module.types);
268
269                if let TypeInner::Sampler { .. } = *inner {
270                    self.need_bake_expressions.insert(exp_handle);
271                }
272            }
273        }
274        for statement in func.body.iter() {
275            match *statement {
276                crate::Statement::SubgroupCollectiveOperation {
277                    op: _,
278                    collective_op: crate::CollectiveOperation::InclusiveScan,
279                    argument,
280                    result: _,
281                } => {
282                    self.need_bake_expressions.insert(argument);
283                }
284                crate::Statement::Atomic {
285                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
286                    ..
287                } => {
288                    self.need_bake_expressions.insert(cmp);
289                }
290                _ => {}
291            }
292        }
293    }
294
    /// Writes `module` as HLSL to the output and returns reflection info.
    ///
    /// Output order: the optional special-constants buffer, the dynamic storage
    /// buffer offset buffers, matCx2 typedefs/helpers, structs, special and
    /// wrapped helper functions, named constants, globals, entry-point interface
    /// structs, regular functions, and finally the entry points selected by
    /// `pipeline_options.entry_point` (presumably all of them when it is `None` —
    /// see `get_entry_points`).
    ///
    /// # Errors
    /// Returns `Error::EntryPointNotFound` if the requested entry point does not
    /// exist, and propagates any write/translation error. An entry point whose
    /// resource bindings cannot be resolved is *not* an error of this function
    /// (unless `fake_missing_bindings` is set, it is skipped and its slot in
    /// `entry_point_names` holds the binding error instead of a name).
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // `space0` is the default register space, so it is left implicit.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One constant buffer of `bt.size` uint offsets per configured bind group.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // NOTE(review): `unwrap` assumes structs are never empty
                // (presumably guaranteed by validation — confirm).
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                // If some entry point returns this struct, tell `write_struct`
                // which stage's outputs it describes.
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(
                module,
                &ep.function,
                ep.stage,
                &ep_name,
                fragment_entry_point,
            )?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            // (a function that uses a global whose binding resolves neither as a
            // regular nor as an external-texture resource is skipped entirely).
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info)?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // Same accessibility check as for regular functions, but the first
            // unresolvable binding is reported back to the caller instead of
            // only being logged.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            if ep.stage.compute_like() {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    self.out,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info)?;

            // Separator between entry points. NOTE(review): this compares against
            // the module's last entry point, not the last one in `ep_range`, so a
            // trailing newline can follow the final selected entry point.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
537
538    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
539        match *binding {
540            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
541                write!(self.out, "precise ")?;
542            }
543            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
544                write!(self.out, "noperspective ")?;
545            }
546            crate::Binding::Location {
547                interpolation,
548                sampling,
549                ..
550            } => {
551                if let Some(interpolation) = interpolation {
552                    if let Some(string) = interpolation.to_hlsl_str() {
553                        write!(self.out, "{string} ")?
554                    }
555                }
556
557                if let Some(sampling) = sampling {
558                    if let Some(string) = sampling.to_hlsl_str() {
559                        write!(self.out, "{string} ")?
560                    }
561                }
562            }
563            crate::Binding::BuiltIn(_) => {}
564        }
565
566        Ok(())
567    }
568
569    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
570    // if they are struct, so that the `stage` argument here could be omitted.
571    fn write_semantic(
572        &mut self,
573        binding: &Option<crate::Binding>,
574        stage: Option<(ShaderStage, Io)>,
575    ) -> BackendResult {
576        match *binding {
577            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
578                if builtin == crate::BuiltIn::ViewIndex
579                    && self.options.shader_model < ShaderModel::V6_1
580                {
581                    return Err(Error::ShaderModelTooLow(
582                        "used @builtin(view_index) or SV_ViewID".to_string(),
583                        ShaderModel::V6_1,
584                    ));
585                }
586                let builtin_str = builtin.to_hlsl_str()?;
587                write!(self.out, " : {builtin_str}")?;
588            }
589            Some(crate::Binding::Location {
590                blend_src: Some(1), ..
591            }) => {
592                write!(self.out, " : SV_Target1")?;
593            }
594            Some(crate::Binding::Location { location, .. }) => {
595                if stage == Some((ShaderStage::Fragment, Io::Output)) {
596                    write!(self.out, " : SV_Target{location}")?;
597                } else {
598                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
599                }
600            }
601            _ => {}
602        }
603
604        Ok(())
605    }
606
607    fn write_interface_struct(
608        &mut self,
609        module: &Module,
610        shader_stage: (ShaderStage, Io),
611        struct_name: String,
612        mut members: Vec<EpStructMember>,
613    ) -> Result<EntryPointBinding, Error> {
614        // Sort the members so that first come the user-defined varyings
615        // in ascending locations, and then built-ins. This allows VS and FS
616        // interfaces to match with regards to order.
617        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
618
619        write!(self.out, "struct {struct_name}")?;
620        writeln!(self.out, " {{")?;
621        for m in members.iter() {
622            // Sanity check that each IO member is a built-in or is assigned a
623            // location. Also see note about nesting in `write_ep_input_struct`.
624            debug_assert!(m.binding.is_some());
625
626            if is_subgroup_builtin_binding(&m.binding) {
627                continue;
628            }
629            write!(self.out, "{}", back::INDENT)?;
630            if let Some(ref binding) = m.binding {
631                self.write_modifier(binding)?;
632            }
633            self.write_type(module, m.ty)?;
634            write!(self.out, " {}", &m.name)?;
635            self.write_semantic(&m.binding, Some(shader_stage))?;
636            writeln!(self.out, ";")?;
637        }
638        if members.iter().any(|arg| {
639            matches!(
640                arg.binding,
641                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
642            )
643        }) {
644            writeln!(
645                self.out,
646                "{}uint __local_invocation_index : SV_GroupIndex;",
647                back::INDENT
648            )?;
649        }
650        writeln!(self.out, "}};")?;
651        writeln!(self.out)?;
652
653        // See ordering notes on EntryPointInterface fields
654        match shader_stage.1 {
655            Io::Input => {
656                // bring back the original order
657                members.sort_by_key(|m| m.index);
658            }
659            Io::Output => {
660                // keep it sorted by binding
661            }
662        }
663
664        Ok(EntryPointBinding {
665            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
666            ty_name: struct_name,
667            members,
668        })
669    }
670
671    /// Flatten all entry point arguments into a single struct.
672    /// This is needed since we need to re-order them: first placing user locations,
673    /// then built-ins.
674    fn write_ep_input_struct(
675        &mut self,
676        module: &Module,
677        func: &crate::Function,
678        stage: ShaderStage,
679        entry_point_name: &str,
680    ) -> Result<EntryPointBinding, Error> {
681        let struct_name = format!("{stage:?}Input_{entry_point_name}");
682
683        let mut fake_members = Vec::new();
684        for arg in func.arguments.iter() {
685            // NOTE: We don't need to handle nesting structs. All members must
686            // be either built-ins or assigned a location. I.E. `binding` is
687            // `Some`. This is checked in `VaryingContext::validate`. See:
688            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
689            match module.types[arg.ty].inner {
690                TypeInner::Struct { ref members, .. } => {
691                    for member in members.iter() {
692                        let name = self.namer.call_or(&member.name, "member");
693                        let index = fake_members.len() as u32;
694                        fake_members.push(EpStructMember {
695                            name,
696                            ty: member.ty,
697                            binding: member.binding.clone(),
698                            index,
699                        });
700                    }
701                }
702                _ => {
703                    let member_name = self.namer.call_or(&arg.name, "member");
704                    let index = fake_members.len() as u32;
705                    fake_members.push(EpStructMember {
706                        name: member_name,
707                        ty: arg.ty,
708                        binding: arg.binding.clone(),
709                        index,
710                    });
711                }
712            }
713        }
714
715        self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
716    }
717
718    /// Flatten all entry point results into a single struct.
719    /// This is needed since we need to re-order them: first placing user locations,
720    /// then built-ins.
721    fn write_ep_output_struct(
722        &mut self,
723        module: &Module,
724        result: &crate::FunctionResult,
725        stage: ShaderStage,
726        entry_point_name: &str,
727        frag_ep: Option<&FragmentEntryPoint<'_>>,
728    ) -> Result<EntryPointBinding, Error> {
729        let struct_name = format!("{stage:?}Output_{entry_point_name}");
730
731        let empty = [];
732        let members = match module.types[result.ty].inner {
733            TypeInner::Struct { ref members, .. } => members,
734            ref other => {
735                log::error!("Unexpected {other:?} output type without a binding");
736                &empty[..]
737            }
738        };
739
740        // Gather list of fragment input locations. We use this below to remove user-defined
741        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
742        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
743        // writer is supplied with information about the fragment entry point.
744        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
745            let mut fs_input_locs = Vec::new();
746            for arg in frag_ep.func.arguments.iter() {
747                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
748                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
749                    Some(crate::Binding::BuiltIn(_)) | None => {}
750                };
751
752                // NOTE: We don't need to handle struct nesting. See note in
753                // `write_ep_input_struct`.
754                match frag_ep.module.types[arg.ty].inner {
755                    TypeInner::Struct { ref members, .. } => {
756                        for member in members.iter() {
757                            push_if_location(&member.binding);
758                        }
759                    }
760                    _ => push_if_location(&arg.binding),
761                }
762            }
763            fs_input_locs.sort();
764            Some(fs_input_locs)
765        } else {
766            None
767        };
768
769        let mut fake_members = Vec::new();
770        for (index, member) in members.iter().enumerate() {
771            if let Some(ref fs_input_locs) = fs_input_locs {
772                match member.binding {
773                    Some(crate::Binding::Location { location, .. }) => {
774                        if fs_input_locs.binary_search(&location).is_err() {
775                            continue;
776                        }
777                    }
778                    Some(crate::Binding::BuiltIn(_)) | None => {}
779                }
780            }
781
782            let member_name = self.namer.call_or(&member.name, "member");
783            fake_members.push(EpStructMember {
784                name: member_name,
785                ty: member.ty,
786                binding: member.binding.clone(),
787                index: index as u32,
788            });
789        }
790
791        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
792    }
793
794    /// Writes special interface structures for an entry point. The special structures have
795    /// all the fields flattened into them and sorted by binding. They are needed to emulate
796    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
797    fn write_ep_interface(
798        &mut self,
799        module: &Module,
800        func: &crate::Function,
801        stage: ShaderStage,
802        ep_name: &str,
803        frag_ep: Option<&FragmentEntryPoint<'_>>,
804    ) -> Result<EntryPointInterface, Error> {
805        Ok(EntryPointInterface {
806            input: if !func.arguments.is_empty()
807                && (stage == ShaderStage::Fragment
808                    || func
809                        .arguments
810                        .iter()
811                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
812            {
813                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
814            } else {
815                None
816            },
817            output: match func.result {
818                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
819                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
820                }
821                _ => None,
822            },
823        })
824    }
825
826    fn write_ep_argument_initialization(
827        &mut self,
828        ep: &crate::EntryPoint,
829        ep_input: &EntryPointBinding,
830        fake_member: &EpStructMember,
831    ) -> BackendResult {
832        match fake_member.binding {
833            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
834                write!(self.out, "WaveGetLaneCount()")?
835            }
836            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
837                write!(self.out, "WaveGetLaneIndex()")?
838            }
839            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
840                self.out,
841                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
842                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
843            )?,
844            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
845                write!(
846                    self.out,
847                    "{}.__local_invocation_index / WaveGetLaneCount()",
848                    ep_input.arg_name
849                )?;
850            }
851            _ => {
852                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
853            }
854        }
855        Ok(())
856    }
857
858    /// Write an entry point preface that initializes the arguments as specified in IR.
859    fn write_ep_arguments_initialization(
860        &mut self,
861        module: &Module,
862        func: &crate::Function,
863        ep_index: u16,
864    ) -> BackendResult {
865        let ep = &module.entry_points[ep_index as usize];
866        let ep_input = match self
867            .entry_point_io
868            .get_mut(&(ep_index as usize))
869            .unwrap()
870            .input
871            .take()
872        {
873            Some(ep_input) => ep_input,
874            None => return Ok(()),
875        };
876        let mut fake_iter = ep_input.members.iter();
877        for (arg_index, arg) in func.arguments.iter().enumerate() {
878            write!(self.out, "{}", back::INDENT)?;
879            self.write_type(module, arg.ty)?;
880            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
881            write!(self.out, " {arg_name}")?;
882            match module.types[arg.ty].inner {
883                TypeInner::Array { base, size, .. } => {
884                    self.write_array_size(module, base, size)?;
885                    write!(self.out, " = ")?;
886                    self.write_ep_argument_initialization(
887                        ep,
888                        &ep_input,
889                        fake_iter.next().unwrap(),
890                    )?;
891                    writeln!(self.out, ";")?;
892                }
893                TypeInner::Struct { ref members, .. } => {
894                    write!(self.out, " = {{ ")?;
895                    for index in 0..members.len() {
896                        if index != 0 {
897                            write!(self.out, ", ")?;
898                        }
899                        self.write_ep_argument_initialization(
900                            ep,
901                            &ep_input,
902                            fake_iter.next().unwrap(),
903                        )?;
904                    }
905                    writeln!(self.out, " }};")?;
906                }
907                _ => {
908                    write!(self.out, " = ")?;
909                    self.write_ep_argument_initialization(
910                        ep,
911                        &ep_input,
912                        fake_iter.next().unwrap(),
913                    )?;
914                    writeln!(self.out, ";")?;
915                }
916            }
917        }
918        assert!(fake_iter.next().is_none());
919        Ok(())
920    }
921
922    /// Helper method used to write global variables
923    /// # Notes
924    /// Always adds a newline
925    fn write_global(
926        &mut self,
927        module: &Module,
928        handle: Handle<crate::GlobalVariable>,
929    ) -> BackendResult {
930        let global = &module.global_variables[handle];
931        let inner = &module.types[global.ty].inner;
932
933        let handle_ty = match *inner {
934            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
935            _ => inner,
936        };
937
938        // External textures are handled entirely differently, so defer entirely to that method.
939        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
940        // their bindings separately.
941        let is_external_texture = matches!(
942            *handle_ty,
943            TypeInner::Image {
944                class: crate::ImageClass::External,
945                ..
946            }
947        );
948        if is_external_texture {
949            return self.write_global_external_texture(module, handle, global);
950        }
951
952        if let Some(ref binding) = global.binding {
953            if let Err(err) = self.options.resolve_resource_binding(binding) {
954                log::debug!(
955                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
956                    handle,
957                    global.name,
958                    err,
959                );
960                return Ok(());
961            }
962        }
963
964        // Samplers are handled entirely differently, so defer entirely to that method.
965        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
966
967        if is_sampler {
968            return self.write_global_sampler(module, handle, global);
969        }
970
971        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
972        let register_ty = match global.space {
973            crate::AddressSpace::Function => unreachable!("Function address space"),
974            crate::AddressSpace::Private => {
975                write!(self.out, "static ")?;
976                self.write_type(module, global.ty)?;
977                ""
978            }
979            crate::AddressSpace::WorkGroup => {
980                write!(self.out, "groupshared ")?;
981                self.write_type(module, global.ty)?;
982                ""
983            }
984            crate::AddressSpace::TaskPayload => unimplemented!(),
985            crate::AddressSpace::Uniform => {
986                // constant buffer declarations are expected to be inlined, e.g.
987                // `cbuffer foo: register(b0) { field1: type1; }`
988                write!(self.out, "cbuffer")?;
989                "b"
990            }
991            crate::AddressSpace::Storage { access } => {
992                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
993                    ("RW", "u")
994                } else {
995                    ("", "t")
996                };
997                write!(self.out, "{prefix}ByteAddressBuffer")?;
998                register
999            }
1000            crate::AddressSpace::Handle => {
1001                let register = match *handle_ty {
1002                    // all storage textures are UAV, unconditionally
1003                    TypeInner::Image {
1004                        class: crate::ImageClass::Storage { .. },
1005                        ..
1006                    } => "u",
1007                    _ => "t",
1008                };
1009                self.write_type(module, global.ty)?;
1010                register
1011            }
1012            crate::AddressSpace::Immediate => {
1013                // The type of the immediates will be wrapped in `ConstantBuffer`
1014                write!(self.out, "ConstantBuffer<")?;
1015                "b"
1016            }
1017        };
1018
1019        // If the global is a immediate data write the type now because it will be a
1020        // generic argument to `ConstantBuffer`
1021        if global.space == crate::AddressSpace::Immediate {
1022            self.write_global_type(module, global.ty)?;
1023
1024            // need to write the array size if the type was emitted with `write_type`
1025            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1026                self.write_array_size(module, base, size)?;
1027            }
1028
1029            // Close the angled brackets for the generic argument
1030            write!(self.out, ">")?;
1031        }
1032
1033        let name = &self.names[&NameKey::GlobalVariable(handle)];
1034        write!(self.out, " {name}")?;
1035
1036        // Immediates need to be assigned a binding explicitly by the consumer
1037        // since naga has no way to know the binding from the shader alone
1038        if global.space == crate::AddressSpace::Immediate {
1039            match module.types[global.ty].inner {
1040                TypeInner::Struct { .. } => {}
1041                _ => {
1042                    return Err(Error::Unimplemented(format!(
1043                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1044                    )));
1045                }
1046            }
1047
1048            let target = self
1049                .options
1050                .immediates_target
1051                .as_ref()
1052                .expect("No bind target was defined for the immediates block");
1053            write!(self.out, ": register(b{}", target.register)?;
1054            if target.space != 0 {
1055                write!(self.out, ", space{}", target.space)?;
1056            }
1057            write!(self.out, ")")?;
1058        }
1059
1060        if let Some(ref binding) = global.binding {
1061            // this was already resolved earlier when we started evaluating an entry point.
1062            let bt = self.options.resolve_resource_binding(binding).unwrap();
1063
1064            // need to write the binding array size if the type was emitted with `write_type`
1065            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1066                if let Some(overridden_size) = bt.binding_array_size {
1067                    write!(self.out, "[{overridden_size}]")?;
1068                } else {
1069                    self.write_array_size(module, base, size)?;
1070                }
1071            }
1072
1073            write!(self.out, " : register({}{}", register_ty, bt.register)?;
1074            if bt.space != 0 {
1075                write!(self.out, ", space{}", bt.space)?;
1076            }
1077            write!(self.out, ")")?;
1078        } else {
1079            // need to write the array size if the type was emitted with `write_type`
1080            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1081                self.write_array_size(module, base, size)?;
1082            }
1083            if global.space == crate::AddressSpace::Private {
1084                write!(self.out, " = ")?;
1085                if let Some(init) = global.init {
1086                    self.write_const_expression(module, init, &module.global_expressions)?;
1087                } else {
1088                    self.write_default_init(module, global.ty)?;
1089                }
1090            }
1091        }
1092
1093        if global.space == crate::AddressSpace::Uniform {
1094            write!(self.out, " {{ ")?;
1095
1096            self.write_global_type(module, global.ty)?;
1097
1098            write!(
1099                self.out,
1100                " {}",
1101                &self.names[&NameKey::GlobalVariable(handle)]
1102            )?;
1103
1104            // need to write the array size if the type was emitted with `write_type`
1105            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1106                self.write_array_size(module, base, size)?;
1107            }
1108
1109            writeln!(self.out, "; }}")?;
1110        } else {
1111            writeln!(self.out, ";")?;
1112        }
1113
1114        Ok(())
1115    }
1116
1117    fn write_global_sampler(
1118        &mut self,
1119        module: &Module,
1120        handle: Handle<crate::GlobalVariable>,
1121        global: &crate::GlobalVariable,
1122    ) -> BackendResult {
1123        let binding = *global.binding.as_ref().unwrap();
1124
1125        let key = super::SamplerIndexBufferKey {
1126            group: binding.group,
1127        };
1128        self.write_wrapped_sampler_buffer(key)?;
1129
1130        // This was already validated, so we can confidently unwrap it.
1131        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1132
1133        match module.types[global.ty].inner {
1134            TypeInner::Sampler { comparison } => {
1135                // If we are generating a static access, we create a variable for the sampler.
1136                //
1137                // This prevents the DXIL from containing multiple lookups for the sampler, which
1138                // the backend compiler will then have to eliminate. AMD does seem to be able to
1139                // eliminate these, but better safe than sorry.
1140
1141                write!(self.out, "static const ")?;
1142                self.write_type(module, global.ty)?;
1143
1144                let heap_var = if comparison {
1145                    COMPARISON_SAMPLER_HEAP_VAR
1146                } else {
1147                    SAMPLER_HEAP_VAR
1148                };
1149
1150                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1151                let name = &self.names[&NameKey::GlobalVariable(handle)];
1152                writeln!(
1153                    self.out,
1154                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1155                    register = bt.register
1156                )?;
1157            }
1158            TypeInner::BindingArray { .. } => {
1159                // If we are generating a binding array, we cannot directly access the sampler as the index
1160                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1161                // that represents the "base" index into the sampler index buffer. This constant is added
1162                // to the user provided index to get the final index into the sampler index buffer.
1163
1164                let name = &self.names[&NameKey::GlobalVariable(handle)];
1165                writeln!(
1166                    self.out,
1167                    "static const uint {name} = {register};",
1168                    register = bt.register
1169                )?;
1170            }
1171            _ => unreachable!(),
1172        };
1173
1174        Ok(())
1175    }
1176
1177    /// Write the declarations for an external texture global variable.
1178    /// These are emitted as multiple global variables: Three `Texture2D`s
1179    /// (one for each plane) and a parameters cbuffer.
1180    fn write_global_external_texture(
1181        &mut self,
1182        module: &Module,
1183        handle: Handle<crate::GlobalVariable>,
1184        global: &crate::GlobalVariable,
1185    ) -> BackendResult {
1186        let res_binding = global
1187            .binding
1188            .as_ref()
1189            .expect("External texture global variables must have a resource binding");
1190        let ext_tex_bindings = match self
1191            .options
1192            .resolve_external_texture_resource_binding(res_binding)
1193        {
1194            Ok(bindings) => bindings,
1195            Err(err) => {
1196                log::debug!(
1197                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1198                    handle,
1199                    global.name,
1200                    err,
1201                );
1202                return Ok(());
1203            }
1204        };
1205
1206        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1207            write!(
1208                self.out,
1209                "Texture2D<float4> {}: register(t{}",
1210                name, bt.register
1211            )?;
1212            if bt.space != 0 {
1213                write!(self.out, ", space{}", bt.space)?;
1214            }
1215            writeln!(self.out, ");")?;
1216            Ok(())
1217        };
1218        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1219            let plane_name = &self.names
1220                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1221            write_plane(bt, plane_name)?;
1222        }
1223
1224        let params_name = &self.names
1225            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1226        let params_ty_name =
1227            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1228        write!(
1229            self.out,
1230            "cbuffer {}: register(b{}",
1231            params_name, ext_tex_bindings.params.register
1232        )?;
1233        if ext_tex_bindings.params.space != 0 {
1234            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1235        }
1236        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1237
1238        Ok(())
1239    }
1240
1241    /// Helper method used to write global constants
1242    ///
1243    /// # Notes
1244    /// Ends in a newline
1245    fn write_global_constant(
1246        &mut self,
1247        module: &Module,
1248        handle: Handle<crate::Constant>,
1249    ) -> BackendResult {
1250        write!(self.out, "static const ")?;
1251        let constant = &module.constants[handle];
1252        self.write_type(module, constant.ty)?;
1253        let name = &self.names[&NameKey::Constant(handle)];
1254        write!(self.out, " {name}")?;
1255        // Write size for array type
1256        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1257            self.write_array_size(module, base, size)?;
1258        }
1259        write!(self.out, " = ")?;
1260        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1261        writeln!(self.out, ";")?;
1262        Ok(())
1263    }
1264
1265    pub(super) fn write_array_size(
1266        &mut self,
1267        module: &Module,
1268        base: Handle<crate::Type>,
1269        size: crate::ArraySize,
1270    ) -> BackendResult {
1271        write!(self.out, "[")?;
1272
1273        match size.resolve(module.to_ctx())? {
1274            proc::IndexableLength::Known(size) => {
1275                write!(self.out, "{size}")?;
1276            }
1277            proc::IndexableLength::Dynamic => unreachable!(),
1278        }
1279
1280        write!(self.out, "]")?;
1281
1282        if let TypeInner::Array {
1283            base: next_base,
1284            size: next_size,
1285            ..
1286        } = module.types[base].inner
1287        {
1288            self.write_array_size(module, next_base, next_size)?;
1289        }
1290
1291        Ok(())
1292    }
1293
1294    /// Helper method used to write structs
1295    ///
1296    /// # Notes
1297    /// Ends in a newline
1298    fn write_struct(
1299        &mut self,
1300        module: &Module,
1301        handle: Handle<crate::Type>,
1302        members: &[crate::StructMember],
1303        span: u32,
1304        shader_stage: Option<(ShaderStage, Io)>,
1305    ) -> BackendResult {
1306        // Write struct name
1307        let struct_name = &self.names[&NameKey::Type(handle)];
1308        writeln!(self.out, "struct {struct_name} {{")?;
1309
1310        let mut last_offset = 0;
1311        for (index, member) in members.iter().enumerate() {
1312            if member.binding.is_none() && member.offset > last_offset {
1313                // using int as padding should work as long as the backend
1314                // doesn't support a type that's less than 4 bytes in size
1315                // (Error::UnsupportedScalar catches this)
1316                let padding = (member.offset - last_offset) / 4;
1317                for i in 0..padding {
1318                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1319                }
1320            }
1321            let ty_inner = &module.types[member.ty].inner;
1322            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1323
1324            // The indentation is only for readability
1325            write!(self.out, "{}", back::INDENT)?;
1326
1327            match module.types[member.ty].inner {
1328                TypeInner::Array { base, size, .. } => {
1329                    // HLSL arrays are written as `type name[size]`
1330
1331                    self.write_global_type(module, member.ty)?;
1332
1333                    // Write `name`
1334                    write!(
1335                        self.out,
1336                        " {}",
1337                        &self.names[&NameKey::StructMember(handle, index as u32)]
1338                    )?;
1339                    // Write [size]
1340                    self.write_array_size(module, base, size)?;
1341                }
1342                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1343                // See the module-level block comment in mod.rs for details.
1344                TypeInner::Matrix {
1345                    rows,
1346                    columns,
1347                    scalar,
1348                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1349                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1350                    let field_name_key = NameKey::StructMember(handle, index as u32);
1351
1352                    for i in 0..columns as u8 {
1353                        if i != 0 {
1354                            write!(self.out, "; ")?;
1355                        }
1356                        self.write_value_type(module, &vec_ty)?;
1357                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1358                    }
1359                }
1360                _ => {
1361                    // Write modifier before type
1362                    if let Some(ref binding) = member.binding {
1363                        self.write_modifier(binding)?;
1364                    }
1365
1366                    // Even though Naga IR matrices are column-major, we must describe
1367                    // matrices passed from the CPU as being in row-major order.
1368                    // See the module-level block comment in mod.rs for details.
1369                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1370                        write!(self.out, "row_major ")?;
1371                    }
1372
1373                    // Write the member type and name
1374                    self.write_type(module, member.ty)?;
1375                    write!(
1376                        self.out,
1377                        " {}",
1378                        &self.names[&NameKey::StructMember(handle, index as u32)]
1379                    )?;
1380                }
1381            }
1382
1383            self.write_semantic(&member.binding, shader_stage)?;
1384            writeln!(self.out, ";")?;
1385        }
1386
1387        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1388        if members.last().unwrap().binding.is_none() && span > last_offset {
1389            let padding = (span - last_offset) / 4;
1390            for i in 0..padding {
1391                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1392            }
1393        }
1394
1395        writeln!(self.out, "}};")?;
1396        Ok(())
1397    }
1398
1399    /// Helper method used to write global/structs non image/sampler types
1400    ///
1401    /// # Notes
1402    /// Adds no trailing or leading whitespace
1403    pub(super) fn write_global_type(
1404        &mut self,
1405        module: &Module,
1406        ty: Handle<crate::Type>,
1407    ) -> BackendResult {
1408        let matrix_data = get_inner_matrix_data(module, ty);
1409
1410        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1411        // See the module-level block comment in mod.rs for details.
1412        if let Some(MatrixType {
1413            columns,
1414            rows: crate::VectorSize::Bi,
1415            width: 4,
1416        }) = matrix_data
1417        {
1418            write!(self.out, "__mat{}x2", columns as u8)?;
1419        } else {
1420            // Even though Naga IR matrices are column-major, we must describe
1421            // matrices passed from the CPU as being in row-major order.
1422            // See the module-level block comment in mod.rs for details.
1423            if matrix_data.is_some() {
1424                write!(self.out, "row_major ")?;
1425            }
1426
1427            self.write_type(module, ty)?;
1428        }
1429
1430        Ok(())
1431    }
1432
1433    /// Helper method used to write non image/sampler types
1434    ///
1435    /// # Notes
1436    /// Adds no trailing or leading whitespace
1437    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1438        let inner = &module.types[ty].inner;
1439        match *inner {
1440            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1441            // hlsl array has the size separated from the base type
1442            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1443                self.write_type(module, base)?
1444            }
1445            ref other => self.write_value_type(module, other)?,
1446        }
1447
1448        Ok(())
1449    }
1450
1451    /// Helper method used to write value types
1452    ///
1453    /// # Notes
1454    /// Adds no trailing or leading whitespace
1455    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1456        match *inner {
1457            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1458                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1459            }
1460            TypeInner::Vector { size, scalar } => {
1461                write!(
1462                    self.out,
1463                    "{}{}",
1464                    scalar.to_hlsl_str()?,
1465                    common::vector_size_str(size)
1466                )?;
1467            }
1468            TypeInner::Matrix {
1469                columns,
1470                rows,
1471                scalar,
1472            } => {
1473                // The IR supports only float matrix
1474                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1475
1476                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1477                write!(
1478                    self.out,
1479                    "{}{}x{}",
1480                    scalar.to_hlsl_str()?,
1481                    common::vector_size_str(columns),
1482                    common::vector_size_str(rows),
1483                )?;
1484            }
1485            TypeInner::Image {
1486                dim,
1487                arrayed,
1488                class,
1489            } => {
1490                self.write_image_type(dim, arrayed, class)?;
1491            }
1492            TypeInner::Sampler { comparison } => {
1493                let sampler = if comparison {
1494                    "SamplerComparisonState"
1495                } else {
1496                    "SamplerState"
1497                };
1498                write!(self.out, "{sampler}")?;
1499            }
1500            // HLSL arrays are written as `type name[size]`
1501            // Current code is written arrays only as `[size]`
1502            // Base `type` and `name` should be written outside
1503            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1504                self.write_array_size(module, base, size)?;
1505            }
1506            TypeInner::AccelerationStructure { .. } => {
1507                write!(self.out, "RaytracingAccelerationStructure")?;
1508            }
1509            TypeInner::RayQuery { .. } => {
1510                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1511                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1512            }
1513            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1514        }
1515
1516        Ok(())
1517    }
1518
    /// Helper method used to write functions.
    ///
    /// Writes a complete HLSL function definition for `func`, in this order:
    /// - an optional array `typedef` for the return type;
    /// - the return type and `name`;
    /// - the parameter list;
    /// - an optional entry-point output semantic;
    /// - the body, which holds workgroup-memory initialization, entry-point argument
    ///   initialization, local variable declarations and the statements.
    ///
    /// # Parameters
    /// - `name`: the already-chosen HLSL name of the function.
    /// - `func`: the IR function being written.
    /// - `func_ctx`: says whether this is a regular function or an entry point and
    ///   provides per-function name keys and type resolution.
    /// - `info`: validation info; passed to `update_expressions_to_bake` to decide
    ///   which expressions are baked into temporaries.
    ///
    /// # Notes
    /// Ends in a newline
    fn write_function(
        &mut self,
        module: &Module,
        name: &str,
        func: &crate::Function,
        func_ctx: &back::FunctionCtx<'_>,
        info: &valid::FunctionInfo,
    ) -> BackendResult {
        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax

        self.update_expressions_to_bake(module, func, info);

        if let Some(ref result) = func.result {
            // Write typedef if return type is an array.
            // HLSL array types are spelled `type name[size]`, which can't appear in
            // return-type position. So emit `typedef T ret_<name>[N];` and use that
            // single name as the return type below.
            let array_return_type = match module.types[result.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    let array_return_type = self.namer.call(&format!("ret_{name}"));
                    write!(self.out, "typedef ")?;
                    self.write_type(module, result.ty)?;
                    write!(self.out, " {array_return_type}")?;
                    self.write_array_size(module, base, size)?;
                    writeln!(self.out, ";")?;
                    Some(array_return_type)
                }
                _ => None,
            };

            // Write modifier: only an invariant `Position` builtin result gets one here.
            if let Some(
                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
            ) = result.binding
            {
                self.write_modifier(binding)?;
            }

            // Write return type
            match func_ctx.ty {
                back::FunctionType::Function(_) => {
                    if let Some(array_return_type) = array_return_type {
                        write!(self.out, "{array_return_type}")?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
                back::FunctionType::EntryPoint(index) => {
                    // Entry points may return a separate output struct
                    // (`ep_output.ty_name`) instead of the IR result type.
                    // NOTE(review): the unwrap assumes `entry_point_io` has an entry
                    // for every entry point index.
                    if let Some(ref ep_output) =
                        self.entry_point_io.get(&(index as usize)).unwrap().output
                    {
                        write!(self.out, "{}", ep_output.ty_name)?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
            }
        } else {
            write!(self.out, "void")?;
        }

        // Write function name
        write!(self.out, " {name}(")?;

        // Computed once because it affects both the parameter list (an extra
        // `SV_GroupThreadID` parameter) and the start of the body.
        let need_workgroup_variables_initialization =
            self.need_workgroup_variables_initialization(func_ctx, module);

        // Write function arguments.
        // Regular functions write each argument with `write_function_argument`.
        // Entry points take either one input struct or individual arguments
        // that carry semantics.
        match func_ctx.ty {
            back::FunctionType::Function(handle) => {
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }

                    self.write_function_argument(module, handle, arg, index)?;
                }
            }
            back::FunctionType::EntryPoint(ep_index) => {
                if let Some(ref ep_input) =
                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
                {
                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
                } else {
                    let stage = module.entry_points[ep_index as usize].stage;
                    for (index, arg) in func.arguments.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.write_type(module, arg.ty)?;

                        let argument_name =
                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                        write!(self.out, " {argument_name}")?;
                        // Array dimensions follow the name: `type name[size]`.
                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                            self.write_array_size(module, base, size)?;
                        }

                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                    }
                }
                // The extra local-invocation-id parameter lets the body pick a single
                // invocation to zero-initialize workgroup memory. A separator is
                // needed only if something was already written.
                if need_workgroup_variables_initialization {
                    if self
                        .entry_point_io
                        .get(&(ep_index as usize))
                        .unwrap()
                        .input
                        .is_some()
                        || !func.arguments.is_empty()
                    {
                        write!(self.out, ", ")?;
                    }
                    write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
                }
            }
        }
        // Ends of arguments
        write!(self.out, ")")?;

        // Entry-point results carry an output semantic after the parameter list.
        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            let stage = module.entry_points[index as usize].stage;
            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
                self.write_semantic(binding, Some((stage, Io::Output)))?;
            }
        }

        // Function body start
        writeln!(self.out)?;
        writeln!(self.out, "{{")?;

        if need_workgroup_variables_initialization {
            self.write_workgroup_variables_initialization(func_ctx, module)?;
        }

        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            self.write_ep_arguments_initialization(module, func, index)?;
        }

        // Write function local variables
        for (handle, local) in func.local_variables.iter() {
            // Write indentation (only for readability)
            write!(self.out, "{}", back::INDENT)?;

            // Write the local name
            // The leading space is important
            self.write_type(module, local.ty)?;
            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
            // Write size for array type
            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Every local gets an initializer (its IR `init` or a zero value),
            // except ray queries, which are left uninitialized.
            let is_ray_query = match module.types[local.ty].inner {
                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
                TypeInner::RayQuery { .. } => true,
                _ => {
                    write!(self.out, " = ")?;
                    // Write the local initializer if needed
                    if let Some(init) = local.init {
                        self.write_expr(module, init, func_ctx)?;
                    } else {
                        // Zero initialize local variables
                        self.write_default_init(module, local.ty)?;
                    }
                    false
                }
            };
            // Finish the local with `;` and add a newline (only for readability)
            writeln!(self.out, ";")?;
            // If it's a ray query, we also want a tracker variable: a `uint`
            // named with `RAY_QUERY_TRACKER_VARIABLE_PREFIX`, starting at 0.
            if is_ray_query {
                write!(self.out, "{}", back::INDENT)?;
                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
                writeln!(
                    self.out,
                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
                    self.names[&func_ctx.name_key(handle)]
                )?;
            }
        }

        if !func.local_variables.is_empty() {
            writeln!(self.out)?;
        }

        // Write the function body (statement list)
        for sta in func.body.iter() {
            // The indentation should always be 1 when writing the function body
            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
        }

        writeln!(self.out, "}}")?;

        // Named expressions are per-function; reset them before the next function.
        self.named_expressions.clear();

        Ok(())
    }
1718
1719    fn write_function_argument(
1720        &mut self,
1721        module: &Module,
1722        handle: Handle<crate::Function>,
1723        arg: &crate::FunctionArgument,
1724        index: usize,
1725    ) -> BackendResult {
1726        // External texture arguments must be expanded into separate
1727        // arguments for each plane and the params buffer.
1728        if let TypeInner::Image {
1729            class: crate::ImageClass::External,
1730            ..
1731        } = module.types[arg.ty].inner
1732        {
1733            return self.write_function_external_texture_argument(module, handle, index);
1734        }
1735
1736        // Write argument type
1737        let arg_ty = match module.types[arg.ty].inner {
1738            // pointers in function arguments are expected and resolve to `inout`
1739            TypeInner::Pointer { base, .. } => {
1740                //TODO: can we narrow this down to just `in` when possible?
1741                write!(self.out, "inout ")?;
1742                base
1743            }
1744            _ => arg.ty,
1745        };
1746        self.write_type(module, arg_ty)?;
1747
1748        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1749
1750        // Write argument name. Space is important.
1751        write!(self.out, " {argument_name}")?;
1752        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1753            self.write_array_size(module, base, size)?;
1754        }
1755
1756        Ok(())
1757    }
1758
1759    fn write_function_external_texture_argument(
1760        &mut self,
1761        module: &Module,
1762        handle: Handle<crate::Function>,
1763        index: usize,
1764    ) -> BackendResult {
1765        let plane_names = [0, 1, 2].map(|i| {
1766            &self.names[&NameKey::ExternalTextureFunctionArgument(
1767                handle,
1768                index as u32,
1769                ExternalTextureNameKey::Plane(i),
1770            )]
1771        });
1772        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1773            handle,
1774            index as u32,
1775            ExternalTextureNameKey::Params,
1776        )];
1777        let params_ty_name =
1778            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1779        write!(
1780            self.out,
1781            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1782            plane_names[0], plane_names[1], plane_names[2],
1783        )?;
1784        Ok(())
1785    }
1786
1787    fn need_workgroup_variables_initialization(
1788        &mut self,
1789        func_ctx: &back::FunctionCtx,
1790        module: &Module,
1791    ) -> bool {
1792        self.options.zero_initialize_workgroup_memory
1793            && func_ctx.ty.is_compute_like_entry_point(module)
1794            && module.global_variables.iter().any(|(handle, var)| {
1795                !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1796            })
1797    }
1798
1799    fn write_workgroup_variables_initialization(
1800        &mut self,
1801        func_ctx: &back::FunctionCtx,
1802        module: &Module,
1803    ) -> BackendResult {
1804        let level = back::Level(1);
1805
1806        writeln!(
1807            self.out,
1808            "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1809        )?;
1810
1811        let vars = module.global_variables.iter().filter(|&(handle, var)| {
1812            !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1813        });
1814
1815        for (handle, var) in vars {
1816            let name = &self.names[&NameKey::GlobalVariable(handle)];
1817            write!(self.out, "{}{} = ", level.next(), name)?;
1818            self.write_default_init(module, var.ty)?;
1819            writeln!(self.out, ";")?;
1820        }
1821
1822        writeln!(self.out, "{level}}}")?;
1823        self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1824    }
1825
    /// Helper method used to write switches.
    ///
    /// Writes a `Switch` on `selector` with the given `cases` at indentation `level`.
    /// Two workarounds for FXC are applied:
    /// - If every case except the last is an empty fall-through, only one body
    ///   can run. That body is written inside `do { ... } while(false);` instead
    ///   of a `switch`.
    /// - A non-empty case that falls through is followed by copies of the bodies
    ///   of the cases it falls into. FXC does not support fall-through from a
    ///   non-empty case.
    ///
    /// Forwarded `continue`s (see `back::continue_forward`) use a `bool` flag.
    /// The flag is declared before the switch and checked after it, where the
    /// matching `continue` or `break` is re-issued.
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        // Indentation for case labels (level 1) and case bodies (level 2).
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // See docs of `back::continue_forward` module.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // Check if there is only one body, by seeing if all except the last case are fall through
        // with empty bodies. FXC doesn't handle these switches correctly, so
        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
        // is no need to check if one of the cases would have matched.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Start the do-while
            writeln!(self.out, "{level}do {{")?;
            // Note: Expressions have no side-effects so we don't need to emit selector expression.

            // Body
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            // End do-while
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            // Start the switch
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // The new block is not only stylistic, it plays a role here:
                // We might end up having to write the same case body
                // multiple times due to FXC not supporting fallthrough.
                // Therefore, some `Expression`s written by `Statement::Emit`
                // will end up having the same name (`_expr<handle_index>`).
                // So we need to put each case in its own scope.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                // Although FXC does support a series of case clauses before
                // a block[^yes], it does not support fallthrough from a
                // non-empty case block to the next[^no]. If this case has a
                // non-empty body with a fallthrough, emulate that by
                // duplicating the bodies of all the cases it would fall
                // into as extensions of this case's own body. This makes
                // the HLSL output potentially quadratic in the size of the
                // Naga IR.
                //
                // [^yes]: ```hlsl
                // case 1:
                // case 2: do_stuff()
                // ```
                // [^no]: ```hlsl
                // case 1: do_this();
                // case 2: do_that();
                // ```
                if case.fall_through && !case.body.is_empty() {
                    let curr_len = i + 1;
                    // Index of the first later case that does not fall through: the
                    // end of the fall-through chain that starts at `i`.
                    // NOTE(review): the unwrap assumes the IR guarantees some later
                    // case (at least the last one) does not fall through.
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Clear all named expressions that were previously inserted by the statements in the block
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    // The chain ends in a non-fall-through case, so close it with a
                    // `break` unless its body already ends in a terminator.
                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    // An empty fall-through case writes only its label, so control
                    // reaches the next label naturally. Other cases need an explicit
                    // `break` unless they already end in a terminator.
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Handle any forwarded continue statements.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
1970
1971    fn write_index(
1972        &mut self,
1973        module: &Module,
1974        index: Index,
1975        func_ctx: &back::FunctionCtx<'_>,
1976    ) -> BackendResult {
1977        match index {
1978            Index::Static(index) => {
1979                write!(self.out, "{index}")?;
1980            }
1981            Index::Expression(index) => {
1982                self.write_expr(module, index, func_ctx)?;
1983            }
1984        }
1985        Ok(())
1986    }
1987
1988    /// Helper method used to write statements
1989    ///
1990    /// # Notes
1991    /// Always adds a newline
1992    fn write_stmt(
1993        &mut self,
1994        module: &Module,
1995        stmt: &crate::Statement,
1996        func_ctx: &back::FunctionCtx<'_>,
1997        level: back::Level,
1998    ) -> BackendResult {
1999        use crate::Statement;
2000
2001        match *stmt {
2002            Statement::Emit(ref range) => {
2003                for handle in range.clone() {
2004                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2005                    let expr_name = if ptr_class.is_some() {
2006                        // HLSL can't save a pointer-valued expression in a variable,
2007                        // but we shouldn't ever need to: they should never be named expressions,
2008                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
2009                        None
2010                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2011                        // Front end provides names for all variables at the start of writing.
2012                        // But we write them to step by step. We need to recache them
2013                        // Otherwise, we could accidentally write variable name instead of full expression.
2014                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
2015                        Some(self.namer.call(name))
2016                    } else if self.need_bake_expressions.contains(&handle) {
2017                        Some(Baked(handle).to_string())
2018                    } else {
2019                        None
2020                    };
2021
2022                    if let Some(name) = expr_name {
2023                        write!(self.out, "{level}")?;
2024                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2025                    }
2026                }
2027            }
2028            // TODO: copy-paste from glsl-out
2029            Statement::Block(ref block) => {
2030                write!(self.out, "{level}")?;
2031                writeln!(self.out, "{{")?;
2032                for sta in block.iter() {
2033                    // Increase the indentation to help with readability
2034                    self.write_stmt(module, sta, func_ctx, level.next())?
2035                }
2036                writeln!(self.out, "{level}}}")?
2037            }
2038            // TODO: copy-paste from glsl-out
2039            Statement::If {
2040                condition,
2041                ref accept,
2042                ref reject,
2043            } => {
2044                write!(self.out, "{level}")?;
2045                write!(self.out, "if (")?;
2046                self.write_expr(module, condition, func_ctx)?;
2047                writeln!(self.out, ") {{")?;
2048
2049                let l2 = level.next();
2050                for sta in accept {
2051                    // Increase indentation to help with readability
2052                    self.write_stmt(module, sta, func_ctx, l2)?;
2053                }
2054
2055                // If there are no statements in the reject block we skip writing it
2056                // This is only for readability
2057                if !reject.is_empty() {
2058                    writeln!(self.out, "{level}}} else {{")?;
2059
2060                    for sta in reject {
2061                        // Increase indentation to help with readability
2062                        self.write_stmt(module, sta, func_ctx, l2)?;
2063                    }
2064                }
2065
2066                writeln!(self.out, "{level}}}")?
2067            }
2068            // TODO: copy-paste from glsl-out
2069            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2070            Statement::Return { value: None } => {
2071                writeln!(self.out, "{level}return;")?;
2072            }
2073            Statement::Return { value: Some(expr) } => {
2074                let base_ty_res = &func_ctx.info[expr].ty;
2075                let mut resolved = base_ty_res.inner_with(&module.types);
2076                if let TypeInner::Pointer { base, space: _ } = *resolved {
2077                    resolved = &module.types[base].inner;
2078                }
2079
2080                if let TypeInner::Struct { .. } = *resolved {
2081                    // We can safely unwrap here, since we now we working with struct
2082                    let ty = base_ty_res.handle().unwrap();
2083                    let struct_name = &self.names[&NameKey::Type(ty)];
2084                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2085                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2086                    self.write_expr(module, expr, func_ctx)?;
2087                    writeln!(self.out, ";")?;
2088
2089                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2090                    let ep_output = match func_ctx.ty {
2091                        back::FunctionType::Function(_) => None,
2092                        back::FunctionType::EntryPoint(index) => self
2093                            .entry_point_io
2094                            .get(&(index as usize))
2095                            .unwrap()
2096                            .output
2097                            .as_ref(),
2098                    };
2099                    let final_name = match ep_output {
2100                        Some(ep_output) => {
2101                            let final_name = self.namer.call(&variable_name);
2102                            write!(
2103                                self.out,
2104                                "{}const {} {} = {{ ",
2105                                level, ep_output.ty_name, final_name,
2106                            )?;
2107                            for (index, m) in ep_output.members.iter().enumerate() {
2108                                if index != 0 {
2109                                    write!(self.out, ", ")?;
2110                                }
2111                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2112                                write!(self.out, "{variable_name}.{member_name}")?;
2113                            }
2114                            writeln!(self.out, " }};")?;
2115                            final_name
2116                        }
2117                        None => variable_name,
2118                    };
2119                    writeln!(self.out, "{level}return {final_name};")?;
2120                } else {
2121                    write!(self.out, "{level}return ")?;
2122                    self.write_expr(module, expr, func_ctx)?;
2123                    writeln!(self.out, ";")?
2124                }
2125            }
2126            Statement::Store { pointer, value } => {
2127                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2128                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2129                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2130                    self.write_storage_store(
2131                        module,
2132                        var_handle,
2133                        StoreValue::Expression(value),
2134                        func_ctx,
2135                        level,
2136                        None,
2137                    )?;
2138                } else {
2139                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2140                    // See the module-level block comment in mod.rs for details.
2141                    //
2142                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2143                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2144                    enum MatrixAccess {
2145                        Direct {
2146                            base: Handle<crate::Expression>,
2147                            index: u32,
2148                        },
2149                        Struct {
2150                            columns: crate::VectorSize,
2151                            base: Handle<crate::Expression>,
2152                        },
2153                    }
2154
2155                    let get_members = |expr: Handle<crate::Expression>| {
2156                        let resolved = func_ctx.resolve_type(expr, &module.types);
2157                        match *resolved {
2158                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2159                                TypeInner::Struct { ref members, .. } => Some(members),
2160                                _ => None,
2161                            },
2162                            _ => None,
2163                        }
2164                    };
2165
2166                    write!(self.out, "{level}")?;
2167
2168                    let matrix_access_on_lhs =
2169                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2170                            |(matrix_expr, vector, scalar)| match (
2171                                func_ctx.resolve_type(matrix_expr, &module.types),
2172                                &func_ctx.expressions[matrix_expr],
2173                            ) {
2174                                (
2175                                    &TypeInner::Pointer { base: ty, .. },
2176                                    &crate::Expression::AccessIndex { base, index },
2177                                ) if matches!(
2178                                    module.types[ty].inner,
2179                                    TypeInner::Matrix {
2180                                        rows: crate::VectorSize::Bi,
2181                                        ..
2182                                    }
2183                                ) && get_members(base)
2184                                    .map(|members| members[index as usize].binding.is_none())
2185                                    == Some(true) =>
2186                                {
2187                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2188                                }
2189                                _ => {
2190                                    if let Some(MatrixType {
2191                                        columns,
2192                                        rows: crate::VectorSize::Bi,
2193                                        width: 4,
2194                                    }) = get_inner_matrix_of_struct_array_member(
2195                                        module,
2196                                        matrix_expr,
2197                                        func_ctx,
2198                                        true,
2199                                    ) {
2200                                        Some((
2201                                            MatrixAccess::Struct {
2202                                                columns,
2203                                                base: matrix_expr,
2204                                            },
2205                                            vector,
2206                                            scalar,
2207                                        ))
2208                                    } else {
2209                                        None
2210                                    }
2211                                }
2212                            },
2213                        );
2214
2215                    match matrix_access_on_lhs {
2216                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2217                            let base_ty_res = &func_ctx.info[base].ty;
2218                            let resolved = base_ty_res.inner_with(&module.types);
2219                            let ty = match *resolved {
2220                                TypeInner::Pointer { base, .. } => base,
2221                                _ => base_ty_res.handle().unwrap(),
2222                            };
2223
2224                            if let Some(Index::Static(vec_index)) = vector {
2225                                self.write_expr(module, base, func_ctx)?;
2226                                write!(
2227                                    self.out,
2228                                    ".{}_{}",
2229                                    &self.names[&NameKey::StructMember(ty, index)],
2230                                    vec_index
2231                                )?;
2232
2233                                if let Some(scalar_index) = scalar {
2234                                    write!(self.out, "[")?;
2235                                    self.write_index(module, scalar_index, func_ctx)?;
2236                                    write!(self.out, "]")?;
2237                                }
2238
2239                                write!(self.out, " = ")?;
2240                                self.write_expr(module, value, func_ctx)?;
2241                                writeln!(self.out, ";")?;
2242                            } else {
2243                                let access = WrappedStructMatrixAccess { ty, index };
2244                                match (&vector, &scalar) {
2245                                    (&Some(_), &Some(_)) => {
2246                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2247                                            access,
2248                                        )?;
2249                                    }
2250                                    (&Some(_), &None) => {
2251                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2252                                            access,
2253                                        )?;
2254                                    }
2255                                    (&None, _) => {
2256                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2257                                    }
2258                                }
2259
2260                                write!(self.out, "(")?;
2261                                self.write_expr(module, base, func_ctx)?;
2262                                write!(self.out, ", ")?;
2263                                self.write_expr(module, value, func_ctx)?;
2264
2265                                if let Some(Index::Expression(vec_index)) = vector {
2266                                    write!(self.out, ", ")?;
2267                                    self.write_expr(module, vec_index, func_ctx)?;
2268
2269                                    if let Some(scalar_index) = scalar {
2270                                        write!(self.out, ", ")?;
2271                                        self.write_index(module, scalar_index, func_ctx)?;
2272                                    }
2273                                }
2274                                writeln!(self.out, ");")?;
2275                            }
2276                        }
2277                        Some((
2278                            MatrixAccess::Struct { columns, base },
2279                            Some(Index::Expression(vec_index)),
2280                            scalar,
2281                        )) => {
2282                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2283                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2284
2285                            if scalar.is_some() {
2286                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2287                            } else {
2288                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2289                            }
2290                            write!(self.out, "(")?;
2291                            self.write_expr(module, base, func_ctx)?;
2292                            write!(self.out, ", ")?;
2293                            self.write_expr(module, vec_index, func_ctx)?;
2294
2295                            if let Some(scalar_index) = scalar {
2296                                write!(self.out, ", ")?;
2297                                self.write_index(module, scalar_index, func_ctx)?;
2298                            }
2299
2300                            write!(self.out, ", ")?;
2301                            self.write_expr(module, value, func_ctx)?;
2302
2303                            writeln!(self.out, ");")?;
2304                        }
2305                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2306                        | Some((MatrixAccess::Struct { .. }, None, _))
2307                        | None => {
2308                            self.write_expr(module, pointer, func_ctx)?;
2309                            write!(self.out, " = ")?;
2310
2311                            // We cast the RHS of this store in cases where the LHS
2312                            // is a struct member with type:
2313                            //  - matCx2 or
2314                            //  - a (possibly nested) array of matCx2's
2315                            if let Some(MatrixType {
2316                                columns,
2317                                rows: crate::VectorSize::Bi,
2318                                width: 4,
2319                            }) = get_inner_matrix_of_struct_array_member(
2320                                module, pointer, func_ctx, false,
2321                            ) {
2322                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2323                                if let TypeInner::Pointer { base, .. } = *resolved {
2324                                    resolved = &module.types[base].inner;
2325                                }
2326
2327                                write!(self.out, "(__mat{}x2", columns as u8)?;
2328                                if let TypeInner::Array { base, size, .. } = *resolved {
2329                                    self.write_array_size(module, base, size)?;
2330                                }
2331                                write!(self.out, ")")?;
2332                            }
2333
2334                            self.write_expr(module, value, func_ctx)?;
2335                            writeln!(self.out, ";")?
2336                        }
2337                    }
2338                }
2339            }
2340            Statement::Loop {
2341                ref body,
2342                ref continuing,
2343                break_if,
2344            } => {
2345                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2346                let gate_name = (!continuing.is_empty() || break_if.is_some())
2347                    .then(|| self.namer.call("loop_init"));
2348
2349                if let Some((ref decl, _)) = force_loop_bound_statements {
2350                    writeln!(self.out, "{decl}")?;
2351                }
2352                if let Some(ref gate_name) = gate_name {
2353                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2354                }
2355
2356                self.continue_ctx.enter_loop();
2357                writeln!(self.out, "{level}while(true) {{")?;
2358                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2359                    writeln!(self.out, "{break_and_inc}")?;
2360                }
2361                let l2 = level.next();
2362                if let Some(gate_name) = gate_name {
2363                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2364                    let l3 = l2.next();
2365                    for sta in continuing.iter() {
2366                        self.write_stmt(module, sta, func_ctx, l3)?;
2367                    }
2368                    if let Some(condition) = break_if {
2369                        write!(self.out, "{l3}if (")?;
2370                        self.write_expr(module, condition, func_ctx)?;
2371                        writeln!(self.out, ") {{")?;
2372                        writeln!(self.out, "{}break;", l3.next())?;
2373                        writeln!(self.out, "{l3}}}")?;
2374                    }
2375                    writeln!(self.out, "{l2}}}")?;
2376                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2377                }
2378
2379                for sta in body.iter() {
2380                    self.write_stmt(module, sta, func_ctx, l2)?;
2381                }
2382
2383                writeln!(self.out, "{level}}}")?;
2384                self.continue_ctx.exit_loop();
2385            }
2386            Statement::Break => writeln!(self.out, "{level}break;")?,
2387            Statement::Continue => {
2388                if let Some(variable) = self.continue_ctx.continue_encountered() {
2389                    writeln!(self.out, "{level}{variable} = true;")?;
2390                    writeln!(self.out, "{level}break;")?
2391                } else {
2392                    writeln!(self.out, "{level}continue;")?
2393                }
2394            }
2395            Statement::ControlBarrier(barrier) => {
2396                self.write_control_barrier(barrier, level)?;
2397            }
2398            Statement::MemoryBarrier(barrier) => {
2399                self.write_memory_barrier(barrier, level)?;
2400            }
2401            Statement::ImageStore {
2402                image,
2403                coordinate,
2404                array_index,
2405                value,
2406            } => {
2407                write!(self.out, "{level}")?;
2408                self.write_expr(module, image, func_ctx)?;
2409
2410                write!(self.out, "[")?;
2411                if let Some(index) = array_index {
2412                    // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
2413                    write!(self.out, "int3(")?;
2414                    self.write_expr(module, coordinate, func_ctx)?;
2415                    write!(self.out, ", ")?;
2416                    self.write_expr(module, index, func_ctx)?;
2417                    write!(self.out, ")")?;
2418                } else {
2419                    self.write_expr(module, coordinate, func_ctx)?;
2420                }
2421                write!(self.out, "]")?;
2422
2423                write!(self.out, " = ")?;
2424                self.write_expr(module, value, func_ctx)?;
2425                writeln!(self.out, ";")?;
2426            }
2427            Statement::Call {
2428                function,
2429                ref arguments,
2430                result,
2431            } => {
2432                write!(self.out, "{level}")?;
2433                if let Some(expr) = result {
2434                    write!(self.out, "const ")?;
2435                    let name = Baked(expr).to_string();
2436                    let expr_ty = &func_ctx.info[expr].ty;
2437                    let ty_inner = match *expr_ty {
2438                        proc::TypeResolution::Handle(handle) => {
2439                            self.write_type(module, handle)?;
2440                            &module.types[handle].inner
2441                        }
2442                        proc::TypeResolution::Value(ref value) => {
2443                            self.write_value_type(module, value)?;
2444                            value
2445                        }
2446                    };
2447                    write!(self.out, " {name}")?;
2448                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2449                        self.write_array_size(module, base, size)?;
2450                    }
2451                    write!(self.out, " = ")?;
2452                    self.named_expressions.insert(expr, name);
2453                }
2454                let func_name = &self.names[&NameKey::Function(function)];
2455                write!(self.out, "{func_name}(")?;
2456                for (index, argument) in arguments.iter().enumerate() {
2457                    if index != 0 {
2458                        write!(self.out, ", ")?;
2459                    }
2460                    self.write_expr(module, *argument, func_ctx)?;
2461                }
2462                writeln!(self.out, ");")?
2463            }
2464            Statement::Atomic {
2465                pointer,
2466                ref fun,
2467                value,
2468                result,
2469            } => {
2470                write!(self.out, "{level}")?;
2471                let res_var_info = if let Some(res_handle) = result {
2472                    let name = Baked(res_handle).to_string();
2473                    match func_ctx.info[res_handle].ty {
2474                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2475                        proc::TypeResolution::Value(ref value) => {
2476                            self.write_value_type(module, value)?
2477                        }
2478                    };
2479                    write!(self.out, " {name}; ")?;
2480                    self.named_expressions.insert(res_handle, name.clone());
2481                    Some((res_handle, name))
2482                } else {
2483                    None
2484                };
2485                let pointer_space = func_ctx
2486                    .resolve_type(pointer, &module.types)
2487                    .pointer_space()
2488                    .unwrap();
2489                let fun_str = fun.to_hlsl_suffix();
2490                let compare_expr = match *fun {
2491                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2492                    _ => None,
2493                };
2494                match pointer_space {
2495                    crate::AddressSpace::WorkGroup => {
2496                        write!(self.out, "Interlocked{fun_str}(")?;
2497                        self.write_expr(module, pointer, func_ctx)?;
2498                        self.emit_hlsl_atomic_tail(
2499                            module,
2500                            func_ctx,
2501                            fun,
2502                            compare_expr,
2503                            value,
2504                            &res_var_info,
2505                        )?;
2506                    }
2507                    crate::AddressSpace::Storage { .. } => {
2508                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2509                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2510                        let width = match func_ctx.resolve_type(value, &module.types) {
2511                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2512                            _ => "",
2513                        };
2514                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2515                        let chain = mem::take(&mut self.temp_access_chain);
2516                        self.write_storage_address(module, &chain, func_ctx)?;
2517                        self.temp_access_chain = chain;
2518                        self.emit_hlsl_atomic_tail(
2519                            module,
2520                            func_ctx,
2521                            fun,
2522                            compare_expr,
2523                            value,
2524                            &res_var_info,
2525                        )?;
2526                    }
2527                    ref other => {
2528                        return Err(Error::Custom(format!(
2529                            "invalid address space {other:?} for atomic statement"
2530                        )))
2531                    }
2532                }
2533                if let Some(cmp) = compare_expr {
2534                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2535                        write!(
2536                            self.out,
2537                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2538                        )?;
2539                        self.write_expr(module, cmp, func_ctx)?;
2540                        writeln!(self.out, ");")?;
2541                    }
2542                }
2543            }
2544            Statement::ImageAtomic {
2545                image,
2546                coordinate,
2547                array_index,
2548                fun,
2549                value,
2550            } => {
2551                write!(self.out, "{level}")?;
2552
2553                let fun_str = fun.to_hlsl_suffix();
2554                write!(self.out, "Interlocked{fun_str}(")?;
2555                self.write_expr(module, image, func_ctx)?;
2556                write!(self.out, "[")?;
2557                self.write_texture_coordinates(
2558                    "int",
2559                    coordinate,
2560                    array_index,
2561                    None,
2562                    module,
2563                    func_ctx,
2564                )?;
2565                write!(self.out, "],")?;
2566
2567                self.write_expr(module, value, func_ctx)?;
2568                writeln!(self.out, ");")?;
2569            }
2570            Statement::WorkGroupUniformLoad { pointer, result } => {
2571                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2572                write!(self.out, "{level}")?;
2573                let name = Baked(result).to_string();
2574                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2575
2576                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2577            }
2578            Statement::Switch {
2579                selector,
2580                ref cases,
2581            } => {
2582                self.write_switch(module, func_ctx, level, selector, cases)?;
2583            }
2584            Statement::RayQuery { query, ref fun } => {
2585                // There are three possibilities for a ptr to be:
2586                // 1. A variable
2587                // 2. A function argument
2588                // 3. part of a struct
2589                //
2590                // 2 and 3 are not possible, a ray query (in naga IR)
2591                // is not allowed to be passed into a function, and
2592                // all languages disallow it in a struct (you get fun results if
2593                // you try it :) ).
2594                //
2595                // Therefore, the ray query expression must be a variable.
2596                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2597                else {
2598                    unreachable!()
2599                };
2600
2601                let tracker_expr_name = format!(
2602                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2603                    self.names[&func_ctx.name_key(query_var)]
2604                );
2605
2606                match *fun {
2607                    RayQueryFunction::Initialize {
2608                        acceleration_structure,
2609                        descriptor,
2610                    } => {
2611                        self.write_initialize_function(
2612                            module,
2613                            level,
2614                            query,
2615                            acceleration_structure,
2616                            descriptor,
2617                            &tracker_expr_name,
2618                            func_ctx,
2619                        )?;
2620                    }
2621                    RayQueryFunction::Proceed { result } => {
2622                        self.write_proceed(
2623                            module,
2624                            level,
2625                            query,
2626                            result,
2627                            &tracker_expr_name,
2628                            func_ctx,
2629                        )?;
2630                    }
2631                    RayQueryFunction::GenerateIntersection { hit_t } => {
2632                        self.write_generate_intersection(
2633                            module,
2634                            level,
2635                            query,
2636                            hit_t,
2637                            &tracker_expr_name,
2638                            func_ctx,
2639                        )?;
2640                    }
2641                    RayQueryFunction::ConfirmIntersection => {
2642                        self.write_confirm_intersection(
2643                            module,
2644                            level,
2645                            query,
2646                            &tracker_expr_name,
2647                            func_ctx,
2648                        )?;
2649                    }
2650                    RayQueryFunction::Terminate => {
2651                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2652                    }
2653                }
2654            }
2655            Statement::SubgroupBallot { result, predicate } => {
2656                write!(self.out, "{level}")?;
2657                let name = Baked(result).to_string();
2658                write!(self.out, "const uint4 {name} = ")?;
2659                self.named_expressions.insert(result, name);
2660
2661                write!(self.out, "WaveActiveBallot(")?;
2662                match predicate {
2663                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2664                    None => write!(self.out, "true")?,
2665                }
2666                writeln!(self.out, ");")?;
2667            }
2668            Statement::SubgroupCollectiveOperation {
2669                op,
2670                collective_op,
2671                argument,
2672                result,
2673            } => {
2674                write!(self.out, "{level}")?;
2675                write!(self.out, "const ")?;
2676                let name = Baked(result).to_string();
2677                match func_ctx.info[result].ty {
2678                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2679                    proc::TypeResolution::Value(ref value) => {
2680                        self.write_value_type(module, value)?
2681                    }
2682                };
2683                write!(self.out, " {name} = ")?;
2684                self.named_expressions.insert(result, name);
2685
2686                match (collective_op, op) {
2687                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2688                        write!(self.out, "WaveActiveAllTrue(")?
2689                    }
2690                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2691                        write!(self.out, "WaveActiveAnyTrue(")?
2692                    }
2693                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2694                        write!(self.out, "WaveActiveSum(")?
2695                    }
2696                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2697                        write!(self.out, "WaveActiveProduct(")?
2698                    }
2699                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2700                        write!(self.out, "WaveActiveMax(")?
2701                    }
2702                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2703                        write!(self.out, "WaveActiveMin(")?
2704                    }
2705                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2706                        write!(self.out, "WaveActiveBitAnd(")?
2707                    }
2708                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2709                        write!(self.out, "WaveActiveBitOr(")?
2710                    }
2711                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2712                        write!(self.out, "WaveActiveBitXor(")?
2713                    }
2714                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2715                        write!(self.out, "WavePrefixSum(")?
2716                    }
2717                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2718                        write!(self.out, "WavePrefixProduct(")?
2719                    }
2720                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2721                        self.write_expr(module, argument, func_ctx)?;
2722                        write!(self.out, " + WavePrefixSum(")?;
2723                    }
2724                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2725                        self.write_expr(module, argument, func_ctx)?;
2726                        write!(self.out, " * WavePrefixProduct(")?;
2727                    }
2728                    _ => unimplemented!(),
2729                }
2730                self.write_expr(module, argument, func_ctx)?;
2731                writeln!(self.out, ");")?;
2732            }
2733            Statement::SubgroupGather {
2734                mode,
2735                argument,
2736                result,
2737            } => {
2738                write!(self.out, "{level}")?;
2739                write!(self.out, "const ")?;
2740                let name = Baked(result).to_string();
2741                match func_ctx.info[result].ty {
2742                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2743                    proc::TypeResolution::Value(ref value) => {
2744                        self.write_value_type(module, value)?
2745                    }
2746                };
2747                write!(self.out, " {name} = ")?;
2748                self.named_expressions.insert(result, name);
2749                match mode {
2750                    crate::GatherMode::BroadcastFirst => {
2751                        write!(self.out, "WaveReadLaneFirst(")?;
2752                        self.write_expr(module, argument, func_ctx)?;
2753                    }
2754                    crate::GatherMode::QuadBroadcast(index) => {
2755                        write!(self.out, "QuadReadLaneAt(")?;
2756                        self.write_expr(module, argument, func_ctx)?;
2757                        write!(self.out, ", ")?;
2758                        self.write_expr(module, index, func_ctx)?;
2759                    }
2760                    crate::GatherMode::QuadSwap(direction) => {
2761                        match direction {
2762                            crate::Direction::X => {
2763                                write!(self.out, "QuadReadAcrossX(")?;
2764                            }
2765                            crate::Direction::Y => {
2766                                write!(self.out, "QuadReadAcrossY(")?;
2767                            }
2768                            crate::Direction::Diagonal => {
2769                                write!(self.out, "QuadReadAcrossDiagonal(")?;
2770                            }
2771                        }
2772                        self.write_expr(module, argument, func_ctx)?;
2773                    }
2774                    _ => {
2775                        write!(self.out, "WaveReadLaneAt(")?;
2776                        self.write_expr(module, argument, func_ctx)?;
2777                        write!(self.out, ", ")?;
2778                        match mode {
2779                            crate::GatherMode::BroadcastFirst => unreachable!(),
2780                            crate::GatherMode::Broadcast(index)
2781                            | crate::GatherMode::Shuffle(index) => {
2782                                self.write_expr(module, index, func_ctx)?;
2783                            }
2784                            crate::GatherMode::ShuffleDown(index) => {
2785                                write!(self.out, "WaveGetLaneIndex() + ")?;
2786                                self.write_expr(module, index, func_ctx)?;
2787                            }
2788                            crate::GatherMode::ShuffleUp(index) => {
2789                                write!(self.out, "WaveGetLaneIndex() - ")?;
2790                                self.write_expr(module, index, func_ctx)?;
2791                            }
2792                            crate::GatherMode::ShuffleXor(index) => {
2793                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
2794                                self.write_expr(module, index, func_ctx)?;
2795                            }
2796                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2797                            crate::GatherMode::QuadSwap(_) => unreachable!(),
2798                        }
2799                    }
2800                }
2801                writeln!(self.out, ");")?;
2802            }
2803            Statement::CooperativeStore { .. } => unimplemented!(),
2804        }
2805
2806        Ok(())
2807    }
2808
2809    fn write_const_expression(
2810        &mut self,
2811        module: &Module,
2812        expr: Handle<crate::Expression>,
2813        arena: &crate::Arena<crate::Expression>,
2814    ) -> BackendResult {
2815        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2816            writer.write_const_expression(module, expr, arena)
2817        })
2818    }
2819
2820    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2821        match literal {
2822            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2823            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2824            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2825            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2826            // `-2147483648` is parsed by some compilers as unary negation of
2827            // positive 2147483648, which is too large for an int, causing
2828            // issues for some compilers. Neither DXC nor FXC appear to have
2829            // this problem, but this is not specified and could change. We
2830            // therefore use `-2147483647 - 1` as a precaution.
2831            crate::Literal::I32(value) if value == i32::MIN => {
2832                write!(self.out, "int({} - 1)", value + 1)?
2833            }
2834            // HLSL has no suffix for explicit i32 literals, but not using any suffix
2835            // makes the type ambiguous which prevents overload resolution from
2836            // working. So we explicitly use the int() constructor syntax.
2837            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2838            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2839            // I64 version of the minimum I32 value issue described above.
2840            crate::Literal::I64(value) if value == i64::MIN => {
2841                write!(self.out, "({}L - 1L)", value + 1)?;
2842            }
2843            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2844            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2845            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2846                return Err(Error::Custom(
2847                    "Abstract types should not appear in IR presented to backends".into(),
2848                ));
2849            }
2850        }
2851        Ok(())
2852    }
2853
2854    fn write_possibly_const_expression<E>(
2855        &mut self,
2856        module: &Module,
2857        expr: Handle<crate::Expression>,
2858        expressions: &crate::Arena<crate::Expression>,
2859        write_expression: E,
2860    ) -> BackendResult
2861    where
2862        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2863    {
2864        use crate::Expression;
2865
2866        match expressions[expr] {
2867            Expression::Literal(literal) => {
2868                self.write_literal(literal)?;
2869            }
2870            Expression::Constant(handle) => {
2871                let constant = &module.constants[handle];
2872                if constant.name.is_some() {
2873                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2874                } else {
2875                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
2876                }
2877            }
2878            Expression::ZeroValue(ty) => {
2879                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2880                write!(self.out, "()")?;
2881            }
2882            Expression::Compose { ty, ref components } => {
2883                match module.types[ty].inner {
2884                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2885                        self.write_wrapped_constructor_function_name(
2886                            module,
2887                            WrappedConstructor { ty },
2888                        )?;
2889                    }
2890                    _ => {
2891                        self.write_type(module, ty)?;
2892                    }
2893                };
2894                write!(self.out, "(")?;
2895                for (index, component) in components.iter().enumerate() {
2896                    if index != 0 {
2897                        write!(self.out, ", ")?;
2898                    }
2899                    write_expression(self, *component)?;
2900                }
2901                write!(self.out, ")")?;
2902            }
2903            Expression::Splat { size, value } => {
2904                // hlsl is not supported one value constructor
2905                // if we write, for example, int4(0), dxc returns error:
2906                // error: too few elements in vector initialization (expected 4 elements, have 1)
2907                let number_of_components = match size {
2908                    crate::VectorSize::Bi => "xx",
2909                    crate::VectorSize::Tri => "xxx",
2910                    crate::VectorSize::Quad => "xxxx",
2911                };
2912                write!(self.out, "(")?;
2913                write_expression(self, value)?;
2914                write!(self.out, ").{number_of_components}")?
2915            }
2916            _ => {
2917                return Err(Error::Override);
2918            }
2919        }
2920
2921        Ok(())
2922    }
2923
2924    /// Helper method to write expressions
2925    ///
2926    /// # Notes
2927    /// Doesn't add any newlines or leading/trailing spaces
2928    pub(super) fn write_expr(
2929        &mut self,
2930        module: &Module,
2931        expr: Handle<crate::Expression>,
2932        func_ctx: &back::FunctionCtx<'_>,
2933    ) -> BackendResult {
2934        use crate::Expression;
2935
2936        // Handle the special semantics of vertex_index/instance_index
2937        let ff_input = if self.options.special_constants_binding.is_some() {
2938            func_ctx.is_fixed_function_input(expr, module)
2939        } else {
2940            None
2941        };
2942        let closing_bracket = match ff_input {
2943            Some(crate::BuiltIn::VertexIndex) => {
2944                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2945                ")"
2946            }
2947            Some(crate::BuiltIn::InstanceIndex) => {
2948                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2949                ")"
2950            }
2951            Some(crate::BuiltIn::NumWorkGroups) => {
2952                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
2953                // in compute shaders the special constants contain the number
2954                // of workgroups, which we are using here.
2955                write!(
2956                    self.out,
2957                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2958                )?;
2959                return Ok(());
2960            }
2961            _ => "",
2962        };
2963
2964        if let Some(name) = self.named_expressions.get(&expr) {
2965            write!(self.out, "{name}{closing_bracket}")?;
2966            return Ok(());
2967        }
2968
2969        let expression = &func_ctx.expressions[expr];
2970
2971        match *expression {
2972            Expression::Literal(_)
2973            | Expression::Constant(_)
2974            | Expression::ZeroValue(_)
2975            | Expression::Compose { .. }
2976            | Expression::Splat { .. } => {
2977                self.write_possibly_const_expression(
2978                    module,
2979                    expr,
2980                    func_ctx.expressions,
2981                    |writer, expr| writer.write_expr(module, expr, func_ctx),
2982                )?;
2983            }
2984            Expression::Override(_) => return Err(Error::Override),
2985            // Avoid undefined behaviour for addition, subtraction, and
2986            // multiplication of signed integers by casting operands to
2987            // unsigned, performing the operation, then casting the result back
2988            // to signed.
2989            // TODO(#7109): This relies on the asint()/asuint() functions which only work
2990            // for 32-bit types, so we must find another solution for different bit widths.
2991            Expression::Binary {
2992                op:
2993                    op @ crate::BinaryOperator::Add
2994                    | op @ crate::BinaryOperator::Subtract
2995                    | op @ crate::BinaryOperator::Multiply,
2996                left,
2997                right,
2998            } if matches!(
2999                func_ctx.resolve_type(expr, &module.types).scalar(),
3000                Some(Scalar::I32)
3001            ) =>
3002            {
3003                write!(self.out, "asint(asuint(",)?;
3004                self.write_expr(module, left, func_ctx)?;
3005                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3006                self.write_expr(module, right, func_ctx)?;
3007                write!(self.out, "))")?;
3008            }
3009            // All of the multiplication can be expressed as `mul`,
3010            // except vector * vector, which needs to use the "*" operator.
3011            Expression::Binary {
3012                op: crate::BinaryOperator::Multiply,
3013                left,
3014                right,
3015            } if func_ctx.resolve_type(left, &module.types).is_matrix()
3016                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3017            {
3018                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
3019                write!(self.out, "mul(")?;
3020                self.write_expr(module, right, func_ctx)?;
3021                write!(self.out, ", ")?;
3022                self.write_expr(module, left, func_ctx)?;
3023                write!(self.out, ")")?;
3024            }
3025
3026            // WGSL says that floating-point division by zero should return
3027            // infinity. Microsoft's Direct3D 11 functional specification
3028            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
3029            // says:
3030            //
3031            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3032            //
3033            // which is what we want. The DXIL specification for the FDiv
3034            // instruction corroborates this:
3035            //
3036            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3037            Expression::Binary {
3038                op: crate::BinaryOperator::Divide,
3039                left,
3040                right,
3041            } if matches!(
3042                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3043                Some(ScalarKind::Sint | ScalarKind::Uint)
3044            ) =>
3045            {
3046                write!(self.out, "{DIV_FUNCTION}(")?;
3047                self.write_expr(module, left, func_ctx)?;
3048                write!(self.out, ", ")?;
3049                self.write_expr(module, right, func_ctx)?;
3050                write!(self.out, ")")?;
3051            }
3052
3053            Expression::Binary {
3054                op: crate::BinaryOperator::Modulo,
3055                left,
3056                right,
3057            } if matches!(
3058                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3059                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3060            ) =>
3061            {
3062                write!(self.out, "{MOD_FUNCTION}(")?;
3063                self.write_expr(module, left, func_ctx)?;
3064                write!(self.out, ", ")?;
3065                self.write_expr(module, right, func_ctx)?;
3066                write!(self.out, ")")?;
3067            }
3068
3069            Expression::Binary { op, left, right } => {
3070                write!(self.out, "(")?;
3071                self.write_expr(module, left, func_ctx)?;
3072                write!(self.out, " {} ", back::binary_operation_str(op))?;
3073                self.write_expr(module, right, func_ctx)?;
3074                write!(self.out, ")")?;
3075            }
3076            Expression::Access { base, index } => {
3077                if let Some(crate::AddressSpace::Storage { .. }) =
3078                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3079                {
3080                    // do nothing, the chain is written on `Load`/`Store`
3081                } else {
3082                    // We use the function __get_col_of_matCx2 here in cases
3083                    // where `base`s type resolves to a matCx2 and is part of a
3084                    // struct member with type of (possibly nested) array of matCx2's.
3085                    //
3086                    // Note that this only works for `Load`s and we handle
3087                    // `Store`s differently in `Statement::Store`.
3088                    if let Some(MatrixType {
3089                        columns,
3090                        rows: crate::VectorSize::Bi,
3091                        width: 4,
3092                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3093                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3094                    {
3095                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3096                        self.write_expr(module, base, func_ctx)?;
3097                        write!(self.out, ", ")?;
3098                        self.write_expr(module, index, func_ctx)?;
3099                        write!(self.out, ")")?;
3100                        return Ok(());
3101                    }
3102
3103                    let resolved = func_ctx.resolve_type(base, &module.types);
3104
3105                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3106                        TypeInner::BindingArray { .. } => {
3107                            let uniformity = &func_ctx.info[index].uniformity;
3108
3109                            (true, uniformity.non_uniform_result.is_some())
3110                        }
3111                        _ => (false, false),
3112                    };
3113
3114                    self.write_expr(module, base, func_ctx)?;
3115
3116                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3117                        module, func_ctx, base, resolved,
3118                    );
3119
3120                    if let Some(ref info) = array_sampler_info {
3121                        write!(self.out, "{}[", info.sampler_heap_name)?;
3122                    } else {
3123                        write!(self.out, "[")?;
3124                    }
3125
3126                    let needs_bound_check = self.options.restrict_indexing
3127                        && !indexing_binding_array
3128                        && match resolved.pointer_space() {
3129                            Some(
3130                                crate::AddressSpace::Function
3131                                | crate::AddressSpace::Private
3132                                | crate::AddressSpace::WorkGroup
3133                                | crate::AddressSpace::Immediate
3134                                | crate::AddressSpace::TaskPayload,
3135                            )
3136                            | None => true,
3137                            Some(crate::AddressSpace::Uniform) => {
3138                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3139                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3140                                let bind_target = self
3141                                    .options
3142                                    .resolve_resource_binding(
3143                                        module.global_variables[var_handle]
3144                                            .binding
3145                                            .as_ref()
3146                                            .unwrap(),
3147                                    )
3148                                    .unwrap();
3149                                bind_target.restrict_indexing
3150                            }
3151                            Some(
3152                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3153                            ) => unreachable!(),
3154                        };
3155                    // Decide whether this index needs to be clamped to fall within range.
3156                    let restriction_needed = if needs_bound_check {
3157                        index::access_needs_check(
3158                            base,
3159                            index::GuardedIndex::Expression(index),
3160                            module,
3161                            func_ctx.expressions,
3162                            func_ctx.info,
3163                        )
3164                    } else {
3165                        None
3166                    };
3167                    if let Some(limit) = restriction_needed {
3168                        write!(self.out, "min(uint(")?;
3169                        self.write_expr(module, index, func_ctx)?;
3170                        write!(self.out, "), ")?;
3171                        match limit {
3172                            index::IndexableLength::Known(limit) => {
3173                                write!(self.out, "{}u", limit - 1)?;
3174                            }
3175                            index::IndexableLength::Dynamic => unreachable!(),
3176                        }
3177                        write!(self.out, ")")?;
3178                    } else {
3179                        if non_uniform_qualifier {
3180                            write!(self.out, "NonUniformResourceIndex(")?;
3181                        }
3182                        if let Some(ref info) = array_sampler_info {
3183                            write!(
3184                                self.out,
3185                                "{}[{} + ",
3186                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3187                            )?;
3188                        }
3189                        self.write_expr(module, index, func_ctx)?;
3190                        if array_sampler_info.is_some() {
3191                            write!(self.out, "]")?;
3192                        }
3193                        if non_uniform_qualifier {
3194                            write!(self.out, ")")?;
3195                        }
3196                    }
3197
3198                    write!(self.out, "]")?;
3199                }
3200            }
3201            Expression::AccessIndex { base, index } => {
3202                if let Some(crate::AddressSpace::Storage { .. }) =
3203                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3204                {
3205                    // do nothing, the chain is written on `Load`/`Store`
3206                } else {
3207                    // See if we need to write the matrix column access in a
3208                    // special way since the type of `base` is our special
3209                    // __matCx2 struct.
3210                    if let Some(MatrixType {
3211                        rows: crate::VectorSize::Bi,
3212                        width: 4,
3213                        ..
3214                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3215                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3216                    {
3217                        self.write_expr(module, base, func_ctx)?;
3218                        write!(self.out, "._{index}")?;
3219                        return Ok(());
3220                    }
3221
3222                    let base_ty_res = &func_ctx.info[base].ty;
3223                    let mut resolved = base_ty_res.inner_with(&module.types);
3224                    let base_ty_handle = match *resolved {
3225                        TypeInner::Pointer { base, .. } => {
3226                            resolved = &module.types[base].inner;
3227                            Some(base)
3228                        }
3229                        _ => base_ty_res.handle(),
3230                    };
3231
3232                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3233                    // See the module-level block comment in mod.rs for details.
3234                    //
3235                    // We handle matrix reconstruction here for Loads.
3236                    // Stores are handled directly by `Statement::Store`.
3237                    if let TypeInner::Struct { ref members, .. } = *resolved {
3238                        let member = &members[index as usize];
3239
3240                        match module.types[member.ty].inner {
3241                            TypeInner::Matrix {
3242                                rows: crate::VectorSize::Bi,
3243                                ..
3244                            } if member.binding.is_none() => {
3245                                let ty = base_ty_handle.unwrap();
3246                                self.write_wrapped_struct_matrix_get_function_name(
3247                                    WrappedStructMatrixAccess { ty, index },
3248                                )?;
3249                                write!(self.out, "(")?;
3250                                self.write_expr(module, base, func_ctx)?;
3251                                write!(self.out, ")")?;
3252                                return Ok(());
3253                            }
3254                            _ => {}
3255                        }
3256                    }
3257
3258                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3259                        module, func_ctx, base, resolved,
3260                    );
3261
3262                    if let Some(ref info) = array_sampler_info {
3263                        write!(
3264                            self.out,
3265                            "{}[{}",
3266                            info.sampler_heap_name, info.sampler_index_buffer_name
3267                        )?;
3268                    }
3269
3270                    self.write_expr(module, base, func_ctx)?;
3271
3272                    match *resolved {
3273                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3274                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3275                        // it and generates completely nonsensical DXBC.
3276                        //
3277                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3278                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3279                            // Write vector access as a swizzle
3280                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3281                        }
3282                        TypeInner::Matrix { .. }
3283                        | TypeInner::Array { .. }
3284                        | TypeInner::BindingArray { .. } => {
3285                            if let Some(ref info) = array_sampler_info {
3286                                write!(
3287                                    self.out,
3288                                    "[{} + {index}]",
3289                                    info.binding_array_base_index_name
3290                                )?;
3291                            } else {
3292                                write!(self.out, "[{index}]")?;
3293                            }
3294                        }
3295                        TypeInner::Struct { .. } => {
3296                            // This will never panic in case the type is a `Struct`, this is not true
3297                            // for other types so we can only check while inside this match arm
3298                            let ty = base_ty_handle.unwrap();
3299
3300                            write!(
3301                                self.out,
3302                                ".{}",
3303                                &self.names[&NameKey::StructMember(ty, index)]
3304                            )?
3305                        }
3306                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3307                    }
3308
3309                    if array_sampler_info.is_some() {
3310                        write!(self.out, "]")?;
3311                    }
3312                }
3313            }
3314            Expression::FunctionArgument(pos) => {
3315                let ty = func_ctx.resolve_type(expr, &module.types);
3316
3317                // We know that any external texture function argument has been expanded into
3318                // separate consecutive arguments for each plane and the parameters buffer. And we
3319                // also know that external textures can only ever be used as an argument to another
3320                // function. Therefore we can simply emit each of the expanded arguments in a
3321                // consecutive comma-separated list.
3322                if let TypeInner::Image {
3323                    class: crate::ImageClass::External,
3324                    ..
3325                } = *ty
3326                {
3327                    let plane_names = [0, 1, 2].map(|i| {
3328                        &self.names[&func_ctx
3329                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3330                    });
3331                    let params_name = &self.names[&func_ctx
3332                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3333                    write!(
3334                        self.out,
3335                        "{}, {}, {}, {}",
3336                        plane_names[0], plane_names[1], plane_names[2], params_name
3337                    )?;
3338                } else {
3339                    let key = func_ctx.argument_key(pos);
3340                    let name = &self.names[&key];
3341                    write!(self.out, "{name}")?;
3342                }
3343            }
3344            Expression::ImageSample {
3345                coordinate,
3346                image,
3347                sampler,
3348                clamp_to_edge: true,
3349                gather: None,
3350                array_index: None,
3351                offset: None,
3352                level: crate::SampleLevel::Zero,
3353                depth_ref: None,
3354            } => {
3355                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3356                self.write_expr(module, image, func_ctx)?;
3357                write!(self.out, ", ")?;
3358                self.write_expr(module, sampler, func_ctx)?;
3359                write!(self.out, ", ")?;
3360                self.write_expr(module, coordinate, func_ctx)?;
3361                write!(self.out, ")")?;
3362            }
3363            Expression::ImageSample {
3364                image,
3365                sampler,
3366                gather,
3367                coordinate,
3368                array_index,
3369                offset,
3370                level,
3371                depth_ref,
3372                clamp_to_edge,
3373            } => {
3374                if clamp_to_edge {
3375                    return Err(Error::Custom(
3376                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3377                    ));
3378                }
3379
3380                use crate::SampleLevel as Sl;
3381                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3382
3383                let (base_str, component_str) = match gather {
3384                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3385                    None => ("Sample", ""),
3386                };
3387                let cmp_str = match depth_ref {
3388                    Some(_) => "Cmp",
3389                    None => "",
3390                };
3391                let level_str = match level {
3392                    Sl::Zero if gather.is_none() => "LevelZero",
3393                    Sl::Auto | Sl::Zero => "",
3394                    Sl::Exact(_) => "Level",
3395                    Sl::Bias(_) => "Bias",
3396                    Sl::Gradient { .. } => "Grad",
3397                };
3398
3399                self.write_expr(module, image, func_ctx)?;
3400                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3401                self.write_expr(module, sampler, func_ctx)?;
3402                write!(self.out, ", ")?;
3403                self.write_texture_coordinates(
3404                    "float",
3405                    coordinate,
3406                    array_index,
3407                    None,
3408                    module,
3409                    func_ctx,
3410                )?;
3411
3412                if let Some(depth_ref) = depth_ref {
3413                    write!(self.out, ", ")?;
3414                    self.write_expr(module, depth_ref, func_ctx)?;
3415                }
3416
3417                match level {
3418                    Sl::Auto | Sl::Zero => {}
3419                    Sl::Exact(expr) => {
3420                        write!(self.out, ", ")?;
3421                        self.write_expr(module, expr, func_ctx)?;
3422                    }
3423                    Sl::Bias(expr) => {
3424                        write!(self.out, ", ")?;
3425                        self.write_expr(module, expr, func_ctx)?;
3426                    }
3427                    Sl::Gradient { x, y } => {
3428                        write!(self.out, ", ")?;
3429                        self.write_expr(module, x, func_ctx)?;
3430                        write!(self.out, ", ")?;
3431                        self.write_expr(module, y, func_ctx)?;
3432                    }
3433                }
3434
3435                if let Some(offset) = offset {
3436                    write!(self.out, ", ")?;
3437                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3438                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3439                    write!(self.out, ")")?;
3440                }
3441
3442                write!(self.out, ")")?;
3443            }
3444            Expression::ImageQuery { image, query } => {
3445                // use wrapped image query function
3446                if let TypeInner::Image {
3447                    dim,
3448                    arrayed,
3449                    class,
3450                } = *func_ctx.resolve_type(image, &module.types)
3451                {
3452                    let wrapped_image_query = WrappedImageQuery {
3453                        dim,
3454                        arrayed,
3455                        class,
3456                        query: query.into(),
3457                    };
3458
3459                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3460                    write!(self.out, "(")?;
3461                    // Image always first param
3462                    self.write_expr(module, image, func_ctx)?;
3463                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3464                        write!(self.out, ", ")?;
3465                        self.write_expr(module, level, func_ctx)?;
3466                    }
3467                    write!(self.out, ")")?;
3468                }
3469            }
3470            Expression::ImageLoad {
3471                image,
3472                coordinate,
3473                array_index,
3474                sample,
3475                level,
3476            } => self.write_image_load(
3477                &module,
3478                expr,
3479                func_ctx,
3480                image,
3481                coordinate,
3482                array_index,
3483                sample,
3484                level,
3485            )?,
3486            Expression::GlobalVariable(handle) => {
3487                let global_variable = &module.global_variables[handle];
3488                let ty = &module.types[global_variable.ty].inner;
3489
3490                // In the case of binding arrays of samplers, we need to not write anything
3491                // as the we are in the wrong position to fully write the expression.
3492                //
3493                // The entire writing is done by AccessIndex.
3494                let is_binding_array_of_samplers = match *ty {
3495                    TypeInner::BindingArray { base, .. } => {
3496                        let base_ty = &module.types[base].inner;
3497                        matches!(*base_ty, TypeInner::Sampler { .. })
3498                    }
3499                    _ => false,
3500                };
3501
3502                let is_storage_space =
3503                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3504
3505                // Our external texture global variable has been expanded into multiple
3506                // global variables, one for each plane and the parameters buffer.
3507                // External textures can only ever be used as arguments to a function
3508                // call, and we know that an external texture argument to any function
3509                // will have been expanded to separate consecutive arguments for each
3510                // plane and the parameters buffer. Therefore we can simply emit each of
3511                // the expanded global variables in a consecutive comma-separated list.
3512                if let TypeInner::Image {
3513                    class: crate::ImageClass::External,
3514                    ..
3515                } = *ty
3516                {
3517                    let plane_names = [0, 1, 2].map(|i| {
3518                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3519                            handle,
3520                            ExternalTextureNameKey::Plane(i),
3521                        )]
3522                    });
3523                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3524                        handle,
3525                        ExternalTextureNameKey::Params,
3526                    )];
3527                    write!(
3528                        self.out,
3529                        "{}, {}, {}, {}",
3530                        plane_names[0], plane_names[1], plane_names[2], params_name
3531                    )?;
3532                } else if !is_binding_array_of_samplers && !is_storage_space {
3533                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3534                    write!(self.out, "{name}")?;
3535                }
3536            }
3537            Expression::LocalVariable(handle) => {
3538                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3539            }
3540            Expression::Load { pointer } => {
3541                match func_ctx
3542                    .resolve_type(pointer, &module.types)
3543                    .pointer_space()
3544                {
3545                    Some(crate::AddressSpace::Storage { .. }) => {
3546                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3547                        let result_ty = func_ctx.info[expr].ty.clone();
3548                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3549                    }
3550                    _ => {
3551                        let mut close_paren = false;
3552
3553                        // We cast the value loaded to a native HLSL floatCx2
3554                        // in cases where it is of type:
3555                        //  - __matCx2 or
3556                        //  - a (possibly nested) array of __matCx2's
3557                        if let Some(MatrixType {
3558                            rows: crate::VectorSize::Bi,
3559                            width: 4,
3560                            ..
3561                        }) = get_inner_matrix_of_struct_array_member(
3562                            module, pointer, func_ctx, false,
3563                        )
3564                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3565                        {
3566                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3567                            let ptr_tr = resolved.pointer_base_type();
3568                            if let Some(ptr_ty) =
3569                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3570                            {
3571                                resolved = ptr_ty;
3572                            }
3573
3574                            write!(self.out, "((")?;
3575                            if let TypeInner::Array { base, size, .. } = *resolved {
3576                                self.write_type(module, base)?;
3577                                self.write_array_size(module, base, size)?;
3578                            } else {
3579                                self.write_value_type(module, resolved)?;
3580                            }
3581                            write!(self.out, ")")?;
3582                            close_paren = true;
3583                        }
3584
3585                        self.write_expr(module, pointer, func_ctx)?;
3586
3587                        if close_paren {
3588                            write!(self.out, ")")?;
3589                        }
3590                    }
3591                }
3592            }
3593            Expression::Unary { op, expr } => {
3594                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3595                let op_str = match op {
3596                    crate::UnaryOperator::Negate => {
3597                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3598                            Some(Scalar::I32) => NEG_FUNCTION,
3599                            _ => "-",
3600                        }
3601                    }
3602                    crate::UnaryOperator::LogicalNot => "!",
3603                    crate::UnaryOperator::BitwiseNot => "~",
3604                };
3605                write!(self.out, "{op_str}(")?;
3606                self.write_expr(module, expr, func_ctx)?;
3607                write!(self.out, ")")?;
3608            }
3609            Expression::As {
3610                expr,
3611                kind,
3612                convert,
3613            } => {
3614                let inner = func_ctx.resolve_type(expr, &module.types);
3615                if inner.scalar_kind() == Some(ScalarKind::Float)
3616                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3617                    && convert.is_some()
3618                {
3619                    // Use helper functions for float to int casts in order to
3620                    // avoid undefined behaviour when value is out of range for
3621                    // the target type.
3622                    let fun_name = match (kind, convert) {
3623                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3624                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3625                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3626                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3627                        _ => unreachable!(),
3628                    };
3629                    write!(self.out, "{fun_name}(")?;
3630                    self.write_expr(module, expr, func_ctx)?;
3631                    write!(self.out, ")")?;
3632                } else {
3633                    let close_paren = match convert {
3634                        Some(dst_width) => {
3635                            let scalar = Scalar {
3636                                kind,
3637                                width: dst_width,
3638                            };
3639                            match *inner {
3640                                TypeInner::Vector { size, .. } => {
3641                                    write!(
3642                                        self.out,
3643                                        "{}{}(",
3644                                        scalar.to_hlsl_str()?,
3645                                        common::vector_size_str(size)
3646                                    )?;
3647                                }
3648                                TypeInner::Scalar(_) => {
3649                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3650                                }
3651                                TypeInner::Matrix { columns, rows, .. } => {
3652                                    write!(
3653                                        self.out,
3654                                        "{}{}x{}(",
3655                                        scalar.to_hlsl_str()?,
3656                                        common::vector_size_str(columns),
3657                                        common::vector_size_str(rows)
3658                                    )?;
3659                                }
3660                                _ => {
3661                                    return Err(Error::Unimplemented(format!(
3662                                        "write_expr expression::as {inner:?}"
3663                                    )));
3664                                }
3665                            };
3666                            true
3667                        }
3668                        None => {
3669                            if inner.scalar_width() == Some(8) {
3670                                false
3671                            } else {
3672                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3673                                true
3674                            }
3675                        }
3676                    };
3677                    self.write_expr(module, expr, func_ctx)?;
3678                    if close_paren {
3679                        write!(self.out, ")")?;
3680                    }
3681                }
3682            }
3683            Expression::Math {
3684                fun,
3685                arg,
3686                arg1,
3687                arg2,
3688                arg3,
3689            } => {
3690                use crate::MathFunction as Mf;
3691
3692                enum Function {
3693                    Asincosh { is_sin: bool },
3694                    Atanh,
3695                    Pack2x16float,
3696                    Pack2x16snorm,
3697                    Pack2x16unorm,
3698                    Pack4x8snorm,
3699                    Pack4x8unorm,
3700                    Pack4xI8,
3701                    Pack4xU8,
3702                    Pack4xI8Clamp,
3703                    Pack4xU8Clamp,
3704                    Unpack2x16float,
3705                    Unpack2x16snorm,
3706                    Unpack2x16unorm,
3707                    Unpack4x8snorm,
3708                    Unpack4x8unorm,
3709                    Unpack4xI8,
3710                    Unpack4xU8,
3711                    Dot4I8Packed,
3712                    Dot4U8Packed,
3713                    QuantizeToF16,
3714                    Regular(&'static str),
3715                    MissingIntOverload(&'static str),
3716                    MissingIntReturnType(&'static str),
3717                    CountTrailingZeros,
3718                    CountLeadingZeros,
3719                }
3720
3721                let fun = match fun {
3722                    // comparison
3723                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3724                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3725                        _ => Function::Regular("abs"),
3726                    },
3727                    Mf::Min => Function::Regular("min"),
3728                    Mf::Max => Function::Regular("max"),
3729                    Mf::Clamp => Function::Regular("clamp"),
3730                    Mf::Saturate => Function::Regular("saturate"),
3731                    // trigonometry
3732                    Mf::Cos => Function::Regular("cos"),
3733                    Mf::Cosh => Function::Regular("cosh"),
3734                    Mf::Sin => Function::Regular("sin"),
3735                    Mf::Sinh => Function::Regular("sinh"),
3736                    Mf::Tan => Function::Regular("tan"),
3737                    Mf::Tanh => Function::Regular("tanh"),
3738                    Mf::Acos => Function::Regular("acos"),
3739                    Mf::Asin => Function::Regular("asin"),
3740                    Mf::Atan => Function::Regular("atan"),
3741                    Mf::Atan2 => Function::Regular("atan2"),
3742                    Mf::Asinh => Function::Asincosh { is_sin: true },
3743                    Mf::Acosh => Function::Asincosh { is_sin: false },
3744                    Mf::Atanh => Function::Atanh,
3745                    Mf::Radians => Function::Regular("radians"),
3746                    Mf::Degrees => Function::Regular("degrees"),
3747                    // decomposition
3748                    Mf::Ceil => Function::Regular("ceil"),
3749                    Mf::Floor => Function::Regular("floor"),
3750                    Mf::Round => Function::Regular("round"),
3751                    Mf::Fract => Function::Regular("frac"),
3752                    Mf::Trunc => Function::Regular("trunc"),
3753                    Mf::Modf => Function::Regular(MODF_FUNCTION),
3754                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3755                    Mf::Ldexp => Function::Regular("ldexp"),
3756                    // exponent
3757                    Mf::Exp => Function::Regular("exp"),
3758                    Mf::Exp2 => Function::Regular("exp2"),
3759                    Mf::Log => Function::Regular("log"),
3760                    Mf::Log2 => Function::Regular("log2"),
3761                    Mf::Pow => Function::Regular("pow"),
3762                    // geometry
3763                    Mf::Dot => Function::Regular("dot"),
3764                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
3765                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
3766                    //Mf::Outer => ,
3767                    Mf::Cross => Function::Regular("cross"),
3768                    Mf::Distance => Function::Regular("distance"),
3769                    Mf::Length => Function::Regular("length"),
3770                    Mf::Normalize => Function::Regular("normalize"),
3771                    Mf::FaceForward => Function::Regular("faceforward"),
3772                    Mf::Reflect => Function::Regular("reflect"),
3773                    Mf::Refract => Function::Regular("refract"),
3774                    // computational
3775                    Mf::Sign => Function::Regular("sign"),
3776                    Mf::Fma => Function::Regular("mad"),
3777                    Mf::Mix => Function::Regular("lerp"),
3778                    Mf::Step => Function::Regular("step"),
3779                    Mf::SmoothStep => Function::Regular("smoothstep"),
3780                    Mf::Sqrt => Function::Regular("sqrt"),
3781                    Mf::InverseSqrt => Function::Regular("rsqrt"),
3782                    //Mf::Inverse =>,
3783                    Mf::Transpose => Function::Regular("transpose"),
3784                    Mf::Determinant => Function::Regular("determinant"),
3785                    Mf::QuantizeToF16 => Function::QuantizeToF16,
3786                    // bits
3787                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
3788                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
3789                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3790                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3791                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3792                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3793                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3794                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3795                    // Data Packing
3796                    Mf::Pack2x16float => Function::Pack2x16float,
3797                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
3798                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
3799                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
3800                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
3801                    Mf::Pack4xI8 => Function::Pack4xI8,
3802                    Mf::Pack4xU8 => Function::Pack4xU8,
3803                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3804                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3805                    // Data Unpacking
3806                    Mf::Unpack2x16float => Function::Unpack2x16float,
3807                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3808                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3809                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3810                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3811                    Mf::Unpack4xI8 => Function::Unpack4xI8,
3812                    Mf::Unpack4xU8 => Function::Unpack4xU8,
3813                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3814                };
3815
3816                match fun {
3817                    Function::Asincosh { is_sin } => {
3818                        write!(self.out, "log(")?;
3819                        self.write_expr(module, arg, func_ctx)?;
3820                        write!(self.out, " + sqrt(")?;
3821                        self.write_expr(module, arg, func_ctx)?;
3822                        write!(self.out, " * ")?;
3823                        self.write_expr(module, arg, func_ctx)?;
3824                        match is_sin {
3825                            true => write!(self.out, " + 1.0))")?,
3826                            false => write!(self.out, " - 1.0))")?,
3827                        }
3828                    }
3829                    Function::Atanh => {
3830                        write!(self.out, "0.5 * log((1.0 + ")?;
3831                        self.write_expr(module, arg, func_ctx)?;
3832                        write!(self.out, ") / (1.0 - ")?;
3833                        self.write_expr(module, arg, func_ctx)?;
3834                        write!(self.out, "))")?;
3835                    }
3836                    Function::Pack2x16float => {
3837                        write!(self.out, "(f32tof16(")?;
3838                        self.write_expr(module, arg, func_ctx)?;
3839                        write!(self.out, "[0]) | f32tof16(")?;
3840                        self.write_expr(module, arg, func_ctx)?;
3841                        write!(self.out, "[1]) << 16)")?;
3842                    }
3843                    Function::Pack2x16snorm => {
3844                        let scale = 32767;
3845
3846                        write!(self.out, "uint((int(round(clamp(")?;
3847                        self.write_expr(module, arg, func_ctx)?;
3848                        write!(
3849                            self.out,
3850                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3851                        )?;
3852                        self.write_expr(module, arg, func_ctx)?;
3853                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3854                    }
3855                    Function::Pack2x16unorm => {
3856                        let scale = 65535;
3857
3858                        write!(self.out, "(uint(round(clamp(")?;
3859                        self.write_expr(module, arg, func_ctx)?;
3860                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3861                        self.write_expr(module, arg, func_ctx)?;
3862                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3863                    }
3864                    Function::Pack4x8snorm => {
3865                        let scale = 127;
3866
3867                        write!(self.out, "uint((int(round(clamp(")?;
3868                        self.write_expr(module, arg, func_ctx)?;
3869                        write!(
3870                            self.out,
3871                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3872                        )?;
3873                        self.write_expr(module, arg, func_ctx)?;
3874                        write!(
3875                            self.out,
3876                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3877                        )?;
3878                        self.write_expr(module, arg, func_ctx)?;
3879                        write!(
3880                            self.out,
3881                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3882                        )?;
3883                        self.write_expr(module, arg, func_ctx)?;
3884                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3885                    }
3886                    Function::Pack4x8unorm => {
3887                        let scale = 255;
3888
3889                        write!(self.out, "(uint(round(clamp(")?;
3890                        self.write_expr(module, arg, func_ctx)?;
3891                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3892                        self.write_expr(module, arg, func_ctx)?;
3893                        write!(
3894                            self.out,
3895                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3896                        )?;
3897                        self.write_expr(module, arg, func_ctx)?;
3898                        write!(
3899                            self.out,
3900                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3901                        )?;
3902                        self.write_expr(module, arg, func_ctx)?;
3903                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3904                    }
3905                    fun @ (Function::Pack4xI8
3906                    | Function::Pack4xU8
3907                    | Function::Pack4xI8Clamp
3908                    | Function::Pack4xU8Clamp) => {
3909                        let was_signed =
3910                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3911                        let clamp_bounds = match fun {
3912                            Function::Pack4xI8Clamp => Some(("-128", "127")),
3913                            Function::Pack4xU8Clamp => Some(("0", "255")),
3914                            _ => None,
3915                        };
3916                        if was_signed {
3917                            write!(self.out, "uint(")?;
3918                        }
3919                        let write_arg = |this: &mut Self| -> BackendResult {
3920                            if let Some((min, max)) = clamp_bounds {
3921                                write!(this.out, "clamp(")?;
3922                                this.write_expr(module, arg, func_ctx)?;
3923                                write!(this.out, ", {min}, {max})")?;
3924                            } else {
3925                                this.write_expr(module, arg, func_ctx)?;
3926                            }
3927                            Ok(())
3928                        };
3929                        write!(self.out, "(")?;
3930                        write_arg(self)?;
3931                        write!(self.out, "[0] & 0xFF) | ((")?;
3932                        write_arg(self)?;
3933                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3934                        write_arg(self)?;
3935                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3936                        write_arg(self)?;
3937                        write!(self.out, "[3] & 0xFF) << 24)")?;
3938                        if was_signed {
3939                            write!(self.out, ")")?;
3940                        }
3941                    }
3942
3943                    Function::Unpack2x16float => {
3944                        write!(self.out, "float2(f16tof32(")?;
3945                        self.write_expr(module, arg, func_ctx)?;
3946                        write!(self.out, "), f16tof32((")?;
3947                        self.write_expr(module, arg, func_ctx)?;
3948                        write!(self.out, ") >> 16))")?;
3949                    }
3950                    Function::Unpack2x16snorm => {
3951                        let scale = 32767;
3952
3953                        write!(self.out, "(float2(int2(")?;
3954                        self.write_expr(module, arg, func_ctx)?;
3955                        write!(self.out, " << 16, ")?;
3956                        self.write_expr(module, arg, func_ctx)?;
3957                        write!(self.out, ") >> 16) / {scale}.0)")?;
3958                    }
3959                    Function::Unpack2x16unorm => {
3960                        let scale = 65535;
3961
3962                        write!(self.out, "(float2(")?;
3963                        self.write_expr(module, arg, func_ctx)?;
3964                        write!(self.out, " & 0xFFFF, ")?;
3965                        self.write_expr(module, arg, func_ctx)?;
3966                        write!(self.out, " >> 16) / {scale}.0)")?;
3967                    }
3968                    Function::Unpack4x8snorm => {
3969                        let scale = 127;
3970
3971                        write!(self.out, "(float4(int4(")?;
3972                        self.write_expr(module, arg, func_ctx)?;
3973                        write!(self.out, " << 24, ")?;
3974                        self.write_expr(module, arg, func_ctx)?;
3975                        write!(self.out, " << 16, ")?;
3976                        self.write_expr(module, arg, func_ctx)?;
3977                        write!(self.out, " << 8, ")?;
3978                        self.write_expr(module, arg, func_ctx)?;
3979                        write!(self.out, ") >> 24) / {scale}.0)")?;
3980                    }
3981                    Function::Unpack4x8unorm => {
3982                        let scale = 255;
3983
3984                        write!(self.out, "(float4(")?;
3985                        self.write_expr(module, arg, func_ctx)?;
3986                        write!(self.out, " & 0xFF, ")?;
3987                        self.write_expr(module, arg, func_ctx)?;
3988                        write!(self.out, " >> 8 & 0xFF, ")?;
3989                        self.write_expr(module, arg, func_ctx)?;
3990                        write!(self.out, " >> 16 & 0xFF, ")?;
3991                        self.write_expr(module, arg, func_ctx)?;
3992                        write!(self.out, " >> 24) / {scale}.0)")?;
3993                    }
3994                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3995                        write!(self.out, "(")?;
3996                        if matches!(fun, Function::Unpack4xU8) {
3997                            write!(self.out, "u")?;
3998                        }
3999                        write!(self.out, "int4(")?;
4000                        self.write_expr(module, arg, func_ctx)?;
4001                        write!(self.out, ", ")?;
4002                        self.write_expr(module, arg, func_ctx)?;
4003                        write!(self.out, " >> 8, ")?;
4004                        self.write_expr(module, arg, func_ctx)?;
4005                        write!(self.out, " >> 16, ")?;
4006                        self.write_expr(module, arg, func_ctx)?;
4007                        write!(self.out, " >> 24) << 24 >> 24)")?;
4008                    }
4009                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4010                        let arg1 = arg1.unwrap();
4011
4012                        if self.options.shader_model >= ShaderModel::V6_4 {
4013                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
4014                            let function_name = match fun {
4015                                Function::Dot4I8Packed => "dot4add_i8packed",
4016                                Function::Dot4U8Packed => "dot4add_u8packed",
4017                                _ => unreachable!(),
4018                            };
4019                            write!(self.out, "{function_name}(")?;
4020                            self.write_expr(module, arg, func_ctx)?;
4021                            write!(self.out, ", ")?;
4022                            self.write_expr(module, arg1, func_ctx)?;
4023                            write!(self.out, ", 0)")?;
4024                        } else {
4025                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
4026                            write!(self.out, "dot(")?;
4027
4028                            if matches!(fun, Function::Dot4U8Packed) {
4029                                write!(self.out, "u")?;
4030                            }
4031                            write!(self.out, "int4(")?;
4032                            self.write_expr(module, arg, func_ctx)?;
4033                            write!(self.out, ", ")?;
4034                            self.write_expr(module, arg, func_ctx)?;
4035                            write!(self.out, " >> 8, ")?;
4036                            self.write_expr(module, arg, func_ctx)?;
4037                            write!(self.out, " >> 16, ")?;
4038                            self.write_expr(module, arg, func_ctx)?;
4039                            write!(self.out, " >> 24) << 24 >> 24, ")?;
4040
4041                            if matches!(fun, Function::Dot4U8Packed) {
4042                                write!(self.out, "u")?;
4043                            }
4044                            write!(self.out, "int4(")?;
4045                            self.write_expr(module, arg1, func_ctx)?;
4046                            write!(self.out, ", ")?;
4047                            self.write_expr(module, arg1, func_ctx)?;
4048                            write!(self.out, " >> 8, ")?;
4049                            self.write_expr(module, arg1, func_ctx)?;
4050                            write!(self.out, " >> 16, ")?;
4051                            self.write_expr(module, arg1, func_ctx)?;
4052                            write!(self.out, " >> 24) << 24 >> 24)")?;
4053                        }
4054                    }
4055                    Function::QuantizeToF16 => {
4056                        write!(self.out, "f16tof32(f32tof16(")?;
4057                        self.write_expr(module, arg, func_ctx)?;
4058                        write!(self.out, "))")?;
4059                    }
4060                    Function::Regular(fun_name) => {
4061                        write!(self.out, "{fun_name}(")?;
4062                        self.write_expr(module, arg, func_ctx)?;
4063                        if let Some(arg) = arg1 {
4064                            write!(self.out, ", ")?;
4065                            self.write_expr(module, arg, func_ctx)?;
4066                        }
4067                        if let Some(arg) = arg2 {
4068                            write!(self.out, ", ")?;
4069                            self.write_expr(module, arg, func_ctx)?;
4070                        }
4071                        if let Some(arg) = arg3 {
4072                            write!(self.out, ", ")?;
4073                            self.write_expr(module, arg, func_ctx)?;
4074                        }
4075                        write!(self.out, ")")?
4076                    }
4077                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4078                    // as non-32bit types are DXC only.
4079                    Function::MissingIntOverload(fun_name) => {
4080                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4081                        if let Some(Scalar::I32) = scalar_kind {
4082                            write!(self.out, "asint({fun_name}(asuint(")?;
4083                            self.write_expr(module, arg, func_ctx)?;
4084                            write!(self.out, ")))")?;
4085                        } else {
4086                            write!(self.out, "{fun_name}(")?;
4087                            self.write_expr(module, arg, func_ctx)?;
4088                            write!(self.out, ")")?;
4089                        }
4090                    }
4091                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4092                    // as non-32bit types are DXC only.
4093                    Function::MissingIntReturnType(fun_name) => {
4094                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4095                        if let Some(Scalar::I32) = scalar_kind {
4096                            write!(self.out, "asint({fun_name}(")?;
4097                            self.write_expr(module, arg, func_ctx)?;
4098                            write!(self.out, "))")?;
4099                        } else {
4100                            write!(self.out, "{fun_name}(")?;
4101                            self.write_expr(module, arg, func_ctx)?;
4102                            write!(self.out, ")")?;
4103                        }
4104                    }
4105                    Function::CountTrailingZeros => {
4106                        match *func_ctx.resolve_type(arg, &module.types) {
4107                            TypeInner::Vector { size, scalar } => {
4108                                let s = match size {
4109                                    crate::VectorSize::Bi => ".xx",
4110                                    crate::VectorSize::Tri => ".xxx",
4111                                    crate::VectorSize::Quad => ".xxxx",
4112                                };
4113
4114                                let scalar_width_bits = scalar.width * 8;
4115
4116                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4117                                    write!(
4118                                        self.out,
4119                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4120                                    )?;
4121                                    self.write_expr(module, arg, func_ctx)?;
4122                                    write!(self.out, "))")?;
4123                                } else {
4124                                    // This is only needed for the FXC path, on 32bit signed integers.
4125                                    write!(
4126                                        self.out,
4127                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4128                                    )?;
4129                                    self.write_expr(module, arg, func_ctx)?;
4130                                    write!(self.out, ")))")?;
4131                                }
4132                            }
4133                            TypeInner::Scalar(scalar) => {
4134                                let scalar_width_bits = scalar.width * 8;
4135
4136                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4137                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4138                                    self.write_expr(module, arg, func_ctx)?;
4139                                    write!(self.out, "))")?;
4140                                } else {
4141                                    // This is only needed for the FXC path, on 32bit signed integers.
4142                                    write!(
4143                                        self.out,
4144                                        "asint(min({scalar_width_bits}u, firstbitlow("
4145                                    )?;
4146                                    self.write_expr(module, arg, func_ctx)?;
4147                                    write!(self.out, ")))")?;
4148                                }
4149                            }
4150                            _ => unreachable!(),
4151                        }
4152
4153                        return Ok(());
4154                    }
4155                    Function::CountLeadingZeros => {
4156                        match *func_ctx.resolve_type(arg, &module.types) {
4157                            TypeInner::Vector { size, scalar } => {
4158                                let s = match size {
4159                                    crate::VectorSize::Bi => ".xx",
4160                                    crate::VectorSize::Tri => ".xxx",
4161                                    crate::VectorSize::Quad => ".xxxx",
4162                                };
4163
4164                                // scalar width - 1
4165                                let constant = scalar.width * 8 - 1;
4166
4167                                if scalar.kind == ScalarKind::Uint {
4168                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4169                                    self.write_expr(module, arg, func_ctx)?;
4170                                    write!(self.out, "))")?;
4171                                } else {
4172                                    let conversion_func = match scalar.width {
4173                                        4 => "asint",
4174                                        _ => "",
4175                                    };
4176                                    write!(self.out, "(")?;
4177                                    self.write_expr(module, arg, func_ctx)?;
4178                                    write!(
4179                                        self.out,
4180                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4181                                    )?;
4182                                    self.write_expr(module, arg, func_ctx)?;
4183                                    write!(self.out, ")))")?;
4184                                }
4185                            }
4186                            TypeInner::Scalar(scalar) => {
4187                                // scalar width - 1
4188                                let constant = scalar.width * 8 - 1;
4189
4190                                if let ScalarKind::Uint = scalar.kind {
4191                                    write!(self.out, "({constant}u - firstbithigh(")?;
4192                                    self.write_expr(module, arg, func_ctx)?;
4193                                    write!(self.out, "))")?;
4194                                } else {
4195                                    let conversion_func = match scalar.width {
4196                                        4 => "asint",
4197                                        _ => "",
4198                                    };
4199                                    write!(self.out, "(")?;
4200                                    self.write_expr(module, arg, func_ctx)?;
4201                                    write!(
4202                                        self.out,
4203                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4204                                    )?;
4205                                    self.write_expr(module, arg, func_ctx)?;
4206                                    write!(self.out, ")))")?;
4207                                }
4208                            }
4209                            _ => unreachable!(),
4210                        }
4211
4212                        return Ok(());
4213                    }
4214                }
4215            }
4216            Expression::Swizzle {
4217                size,
4218                vector,
4219                pattern,
4220            } => {
4221                self.write_expr(module, vector, func_ctx)?;
4222                write!(self.out, ".")?;
4223                for &sc in pattern[..size as usize].iter() {
4224                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4225                }
4226            }
4227            Expression::ArrayLength(expr) => {
4228                let var_handle = match func_ctx.expressions[expr] {
4229                    Expression::AccessIndex { base, index: _ } => {
4230                        match func_ctx.expressions[base] {
4231                            Expression::GlobalVariable(handle) => handle,
4232                            _ => unreachable!(),
4233                        }
4234                    }
4235                    Expression::GlobalVariable(handle) => handle,
4236                    _ => unreachable!(),
4237                };
4238
4239                let var = &module.global_variables[var_handle];
4240                let (offset, stride) = match module.types[var.ty].inner {
4241                    TypeInner::Array { stride, .. } => (0, stride),
4242                    TypeInner::Struct { ref members, .. } => {
4243                        let last = members.last().unwrap();
4244                        let stride = match module.types[last.ty].inner {
4245                            TypeInner::Array { stride, .. } => stride,
4246                            _ => unreachable!(),
4247                        };
4248                        (last.offset, stride)
4249                    }
4250                    _ => unreachable!(),
4251                };
4252
4253                let storage_access = match var.space {
4254                    crate::AddressSpace::Storage { access } => access,
4255                    _ => crate::StorageAccess::default(),
4256                };
4257                let wrapped_array_length = WrappedArrayLength {
4258                    writable: storage_access.contains(crate::StorageAccess::STORE),
4259                };
4260
4261                write!(self.out, "((")?;
4262                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4263                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4264                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4265            }
4266            Expression::Derivative { axis, ctrl, expr } => {
4267                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4268                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4269                    let tail = match ctrl {
4270                        Ctrl::Coarse => "coarse",
4271                        Ctrl::Fine => "fine",
4272                        Ctrl::None => unreachable!(),
4273                    };
4274                    write!(self.out, "abs(ddx_{tail}(")?;
4275                    self.write_expr(module, expr, func_ctx)?;
4276                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4277                    self.write_expr(module, expr, func_ctx)?;
4278                    write!(self.out, "))")?
4279                } else {
4280                    let fun_str = match (axis, ctrl) {
4281                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4282                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4283                        (Axis::X, Ctrl::None) => "ddx",
4284                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4285                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4286                        (Axis::Y, Ctrl::None) => "ddy",
4287                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4288                        (Axis::Width, Ctrl::None) => "fwidth",
4289                    };
4290                    write!(self.out, "{fun_str}(")?;
4291                    self.write_expr(module, expr, func_ctx)?;
4292                    write!(self.out, ")")?
4293                }
4294            }
4295            Expression::Relational { fun, argument } => {
4296                use crate::RelationalFunction as Rf;
4297
4298                let fun_str = match fun {
4299                    Rf::All => "all",
4300                    Rf::Any => "any",
4301                    Rf::IsNan => "isnan",
4302                    Rf::IsInf => "isinf",
4303                };
4304                write!(self.out, "{fun_str}(")?;
4305                self.write_expr(module, argument, func_ctx)?;
4306                write!(self.out, ")")?
4307            }
4308            Expression::Select {
4309                condition,
4310                accept,
4311                reject,
4312            } => {
4313                write!(self.out, "(")?;
4314                self.write_expr(module, condition, func_ctx)?;
4315                write!(self.out, " ? ")?;
4316                self.write_expr(module, accept, func_ctx)?;
4317                write!(self.out, " : ")?;
4318                self.write_expr(module, reject, func_ctx)?;
4319                write!(self.out, ")")?
4320            }
4321            Expression::RayQueryGetIntersection { query, committed } => {
4322                // For reasoning, see write_stmt
4323                let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4324                    unreachable!()
4325                };
4326
4327                let tracker_expr_name = format!(
4328                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4329                    self.names[&func_ctx.name_key(query_var)]
4330                );
4331
4332                if committed {
4333                    write!(self.out, "GetCommittedIntersection(")?;
4334                    self.write_expr(module, query, func_ctx)?;
4335                    write!(self.out, ", {tracker_expr_name})")?;
4336                } else {
4337                    write!(self.out, "GetCandidateIntersection(")?;
4338                    self.write_expr(module, query, func_ctx)?;
4339                    write!(self.out, ", {tracker_expr_name})")?;
4340                }
4341            }
4342            // Not supported yet
4343            Expression::RayQueryVertexPositions { .. }
4344            | Expression::CooperativeLoad { .. }
4345            | Expression::CooperativeMultiplyAdd { .. } => {
4346                unreachable!()
4347            }
4348            // Nothing to do here, since call expression already cached
4349            Expression::CallResult(_)
4350            | Expression::AtomicResult { .. }
4351            | Expression::WorkGroupUniformLoadResult { .. }
4352            | Expression::RayQueryProceedResult
4353            | Expression::SubgroupBallotResult
4354            | Expression::SubgroupOperationResult { .. } => {}
4355        }
4356
4357        if !closing_bracket.is_empty() {
4358            write!(self.out, "{closing_bracket}")?;
4359        }
4360        Ok(())
4361    }
4362
4363    #[allow(clippy::too_many_arguments)]
4364    fn write_image_load(
4365        &mut self,
4366        module: &&Module,
4367        expr: Handle<crate::Expression>,
4368        func_ctx: &back::FunctionCtx,
4369        image: Handle<crate::Expression>,
4370        coordinate: Handle<crate::Expression>,
4371        array_index: Option<Handle<crate::Expression>>,
4372        sample: Option<Handle<crate::Expression>>,
4373        level: Option<Handle<crate::Expression>>,
4374    ) -> Result<(), Error> {
4375        let mut wrapping_type = None;
4376        match *func_ctx.resolve_type(image, &module.types) {
4377            TypeInner::Image {
4378                class: crate::ImageClass::External,
4379                ..
4380            } => {
4381                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4382                self.write_expr(module, image, func_ctx)?;
4383                write!(self.out, ", ")?;
4384                self.write_expr(module, coordinate, func_ctx)?;
4385                write!(self.out, ")")?;
4386                return Ok(());
4387            }
4388            TypeInner::Image {
4389                class: crate::ImageClass::Storage { format, .. },
4390                ..
4391            } => {
4392                if format.single_component() {
4393                    wrapping_type = Some(Scalar::from(format));
4394                }
4395            }
4396            _ => {}
4397        }
4398        if let Some(scalar) = wrapping_type {
4399            write!(
4400                self.out,
4401                "{}{}(",
4402                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4403                scalar.to_hlsl_str()?
4404            )?;
4405        }
4406        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4407        self.write_expr(module, image, func_ctx)?;
4408        write!(self.out, ".Load(")?;
4409
4410        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4411
4412        if let Some(sample) = sample {
4413            write!(self.out, ", ")?;
4414            self.write_expr(module, sample, func_ctx)?;
4415        }
4416
4417        // close bracket for Load function
4418        write!(self.out, ")")?;
4419
4420        if wrapping_type.is_some() {
4421            write!(self.out, ")")?;
4422        }
4423
4424        // return x component if return type is scalar
4425        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4426            write!(self.out, ".x")?;
4427        }
4428        Ok(())
4429    }
4430
4431    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4432    /// can be generated later.
4433    fn sampler_binding_array_info_from_expression(
4434        &mut self,
4435        module: &Module,
4436        func_ctx: &back::FunctionCtx<'_>,
4437        base: Handle<crate::Expression>,
4438        resolved: &TypeInner,
4439    ) -> Option<BindingArraySamplerInfo> {
4440        if let TypeInner::BindingArray {
4441            base: base_ty_handle,
4442            ..
4443        } = *resolved
4444        {
4445            let base_ty = &module.types[base_ty_handle].inner;
4446            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4447                let base = &func_ctx.expressions[base];
4448
4449                if let crate::Expression::GlobalVariable(handle) = *base {
4450                    let variable = &module.global_variables[handle];
4451
4452                    let sampler_heap_name = match comparison {
4453                        true => COMPARISON_SAMPLER_HEAP_VAR,
4454                        false => SAMPLER_HEAP_VAR,
4455                    };
4456
4457                    return Some(BindingArraySamplerInfo {
4458                        sampler_heap_name,
4459                        sampler_index_buffer_name: self
4460                            .wrapped
4461                            .sampler_index_buffers
4462                            .get(&super::SamplerIndexBufferKey {
4463                                group: variable.binding.unwrap().group,
4464                            })
4465                            .unwrap()
4466                            .clone(),
4467                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4468                            .clone(),
4469                    });
4470                }
4471            }
4472        }
4473
4474        None
4475    }
4476
4477    fn write_named_expr(
4478        &mut self,
4479        module: &Module,
4480        handle: Handle<crate::Expression>,
4481        name: String,
4482        // The expression which is being named.
4483        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4484        named: Handle<crate::Expression>,
4485        ctx: &back::FunctionCtx,
4486    ) -> BackendResult {
4487        match ctx.info[named].ty {
4488            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4489                TypeInner::Struct { .. } => {
4490                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4491                    write!(self.out, "{ty_name}")?;
4492                }
4493                _ => {
4494                    self.write_type(module, ty_handle)?;
4495                }
4496            },
4497            proc::TypeResolution::Value(ref inner) => {
4498                self.write_value_type(module, inner)?;
4499            }
4500        }
4501
4502        let resolved = ctx.resolve_type(named, &module.types);
4503
4504        write!(self.out, " {name}")?;
4505        // If rhs is a array type, we should write array size
4506        if let TypeInner::Array { base, size, .. } = *resolved {
4507            self.write_array_size(module, base, size)?;
4508        }
4509        write!(self.out, " = ")?;
4510        self.write_expr(module, handle, ctx)?;
4511        writeln!(self.out, ";")?;
4512        self.named_expressions.insert(named, name);
4513
4514        Ok(())
4515    }
4516
4517    /// Helper function that write default zero initialization
4518    pub(super) fn write_default_init(
4519        &mut self,
4520        module: &Module,
4521        ty: Handle<crate::Type>,
4522    ) -> BackendResult {
4523        write!(self.out, "(")?;
4524        self.write_type(module, ty)?;
4525        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4526            self.write_array_size(module, base, size)?;
4527        }
4528        write!(self.out, ")0")?;
4529        Ok(())
4530    }
4531
4532    fn write_control_barrier(
4533        &mut self,
4534        barrier: crate::Barrier,
4535        level: back::Level,
4536    ) -> BackendResult {
4537        if barrier.contains(crate::Barrier::STORAGE) {
4538            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4539        }
4540        if barrier.contains(crate::Barrier::WORK_GROUP) {
4541            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4542        }
4543        if barrier.contains(crate::Barrier::SUB_GROUP) {
4544            // Does not exist in DirectX
4545        }
4546        if barrier.contains(crate::Barrier::TEXTURE) {
4547            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4548        }
4549        Ok(())
4550    }
4551
4552    fn write_memory_barrier(
4553        &mut self,
4554        barrier: crate::Barrier,
4555        level: back::Level,
4556    ) -> BackendResult {
4557        if barrier.contains(crate::Barrier::STORAGE) {
4558            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4559        }
4560        if barrier.contains(crate::Barrier::WORK_GROUP) {
4561            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4562        }
4563        if barrier.contains(crate::Barrier::SUB_GROUP) {
4564            // Does not exist in DirectX
4565        }
4566        if barrier.contains(crate::Barrier::TEXTURE) {
4567            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4568        }
4569        Ok(())
4570    }
4571
4572    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4573    fn emit_hlsl_atomic_tail(
4574        &mut self,
4575        module: &Module,
4576        func_ctx: &back::FunctionCtx<'_>,
4577        fun: &crate::AtomicFunction,
4578        compare_expr: Option<Handle<crate::Expression>>,
4579        value: Handle<crate::Expression>,
4580        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4581    ) -> BackendResult {
4582        if let Some(cmp) = compare_expr {
4583            write!(self.out, ", ")?;
4584            self.write_expr(module, cmp, func_ctx)?;
4585        }
4586        write!(self.out, ", ")?;
4587        if let crate::AtomicFunction::Subtract = *fun {
4588            // we just wrote `InterlockedAdd`, so negate the argument
4589            write!(self.out, "-")?;
4590        }
4591        self.write_expr(module, value, func_ctx)?;
4592        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4593            write!(self.out, ", ")?;
4594            if compare_expr.is_some() {
4595                write!(self.out, "{res_name}.old_value")?;
4596            } else {
4597                write!(self.out, "{res_name}")?;
4598            }
4599        }
4600        writeln!(self.out, ");")?;
4601        Ok(())
4602    }
4603}
4604
/// Shape and component width of a matrix type, as extracted by matrix helpers in this
/// module such as `get_inner_matrix_data`.
pub(super) struct MatrixType {
    /// Number of columns of the matrix.
    pub(super) columns: crate::VectorSize,
    /// Number of rows of the matrix.
    pub(super) rows: crate::VectorSize,
    /// Width in bytes of each scalar component (copied from the matrix's `scalar.width`).
    pub(super) width: crate::Bytes,
}
4610
4611pub(super) fn get_inner_matrix_data(
4612    module: &Module,
4613    handle: Handle<crate::Type>,
4614) -> Option<MatrixType> {
4615    match module.types[handle].inner {
4616        TypeInner::Matrix {
4617            columns,
4618            rows,
4619            scalar,
4620        } => Some(MatrixType {
4621            columns,
4622            rows,
4623            width: scalar.width,
4624        }),
4625        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4626        _ => None,
4627    }
4628}
4629
4630/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4631/// returns a tuple of the matrix, the column (vector) index (if present), and
4632/// the row (scalar) index (if present).
4633fn find_matrix_in_access_chain(
4634    module: &Module,
4635    base: Handle<crate::Expression>,
4636    func_ctx: &back::FunctionCtx<'_>,
4637) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4638    let mut current_base = base;
4639    let mut vector = None;
4640    let mut scalar = None;
4641    loop {
4642        let resolved_tr = func_ctx
4643            .resolve_type(current_base, &module.types)
4644            .pointer_base_type();
4645        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4646
4647        match *resolved {
4648            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4649            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4650            _ => return None,
4651        }
4652
4653        let index;
4654        (current_base, index) = match func_ctx.expressions[current_base] {
4655            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4656            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4657            _ => return None,
4658        };
4659
4660        match *resolved {
4661            TypeInner::Scalar(_) => scalar = Some(index),
4662            TypeInner::Vector { .. } => vector = Some(index),
4663            _ => unreachable!(),
4664        }
4665    }
4666}
4667
4668/// Returns the matrix data if the access chain starting at `base`:
4669/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4670/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4671/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4672pub(super) fn get_inner_matrix_of_struct_array_member(
4673    module: &Module,
4674    base: Handle<crate::Expression>,
4675    func_ctx: &back::FunctionCtx<'_>,
4676    direct: bool,
4677) -> Option<MatrixType> {
4678    let mut mat_data = None;
4679    let mut array_base = None;
4680
4681    let mut current_base = base;
4682    loop {
4683        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4684        if let TypeInner::Pointer { base, .. } = *resolved {
4685            resolved = &module.types[base].inner;
4686        };
4687
4688        match *resolved {
4689            TypeInner::Matrix {
4690                columns,
4691                rows,
4692                scalar,
4693            } => {
4694                mat_data = Some(MatrixType {
4695                    columns,
4696                    rows,
4697                    width: scalar.width,
4698                })
4699            }
4700            TypeInner::Array { base, .. } => {
4701                array_base = Some(base);
4702            }
4703            TypeInner::Struct { .. } => {
4704                if let Some(array_base) = array_base {
4705                    if direct {
4706                        return mat_data;
4707                    } else {
4708                        return get_inner_matrix_data(module, array_base);
4709                    }
4710                }
4711
4712                break;
4713            }
4714            _ => break,
4715        }
4716
4717        current_base = match func_ctx.expressions[current_base] {
4718            crate::Expression::Access { base, .. } => base,
4719            crate::Expression::AccessIndex { base, .. } => base,
4720            _ => break,
4721        };
4722    }
4723    None
4724}
4725
4726/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
4727/// immediate expression, rather than traversing an access chain.
4728fn get_global_uniform_matrix(
4729    module: &Module,
4730    base: Handle<crate::Expression>,
4731    func_ctx: &back::FunctionCtx<'_>,
4732) -> Option<MatrixType> {
4733    let base_tr = func_ctx
4734        .resolve_type(base, &module.types)
4735        .pointer_base_type();
4736    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4737    match (&func_ctx.expressions[base], base_ty) {
4738        (
4739            &crate::Expression::GlobalVariable(handle),
4740            Some(&TypeInner::Matrix {
4741                columns,
4742                rows,
4743                scalar,
4744            }),
4745        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4746            Some(MatrixType {
4747                columns,
4748                rows,
4749                width: scalar.width,
4750            })
4751        }
4752        _ => None,
4753    }
4754}
4755
4756/// Returns the matrix data if the access chain starting at `base`:
4757/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
4758/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4759/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
4760fn get_inner_matrix_of_global_uniform(
4761    module: &Module,
4762    base: Handle<crate::Expression>,
4763    func_ctx: &back::FunctionCtx<'_>,
4764) -> Option<MatrixType> {
4765    let mut mat_data = None;
4766    let mut array_base = None;
4767
4768    let mut current_base = base;
4769    loop {
4770        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4771        if let TypeInner::Pointer { base, .. } = *resolved {
4772            resolved = &module.types[base].inner;
4773        };
4774
4775        match *resolved {
4776            TypeInner::Matrix {
4777                columns,
4778                rows,
4779                scalar,
4780            } => {
4781                mat_data = Some(MatrixType {
4782                    columns,
4783                    rows,
4784                    width: scalar.width,
4785                })
4786            }
4787            TypeInner::Array { base, .. } => {
4788                array_base = Some(base);
4789            }
4790            _ => break,
4791        }
4792
4793        current_base = match func_ctx.expressions[current_base] {
4794            crate::Expression::Access { base, .. } => base,
4795            crate::Expression::AccessIndex { base, .. } => base,
4796            crate::Expression::GlobalVariable(handle)
4797                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4798            {
4799                return mat_data.or_else(|| {
4800                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4801                })
4802            }
4803            _ => break,
4804        };
4805    }
4806    None
4807}