naga/back/hlsl/writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::{self, Write as _},
8    mem,
9};
10
11use super::{
12    help,
13    help::{
14        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15        WrappedZeroValue,
16    },
17    storage::StoreValue,
18    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21    back::{self, get_entry_points, Baked},
22    common,
23    proc::{self, index, ExternalTextureNameKey, NameKey},
24    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Semantic prefix for user-defined (`@location`) varyings, e.g. `LOC0`, `LOC1`.
const LOCATION_SEMANTIC: &str = "LOC";

// Names used for the optional "special constants" constant buffer.

/// Struct type name of the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the special constants buffer.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` member of the special constants buffer holding the first vertex.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` member of the special constants buffer holding the first instance.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` member of the special constants buffer for other data.
const SPECIAL_OTHER: &str = "other";

// Names of helper functions and variables emitted by the backend.

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str = "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Name prefix for the per-ray-query initialization tracker variables.
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix for identifiers generated by the backend itself.
///
/// Several of the helper names above (e.g. `naga_modf`) share this prefix.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
56enum Index {
57    Expression(Handle<crate::Expression>),
58    Static(u32),
59}
60
61pub(super) struct EpStructMember {
62    pub(super) name: String,
63    pub(super) ty: Handle<crate::Type>,
64    // technically, this should always be `Some`
65    // (we `debug_assert!` this in `write_interface_struct`)
66    pub(super) binding: Option<crate::Binding>,
67    pub(super) index: u32,
68}
69
70/// Structure contains information required for generating
71/// wrapped structure of all entry points arguments
72pub(super) struct EntryPointBinding {
73    /// Name of the fake EP argument that contains the struct
74    /// with all the flattened input data.
75    pub(super) arg_name: String,
76    /// Generated structure name
77    pub(super) ty_name: String,
78    /// Members of generated structure
79    pub(super) members: Vec<EpStructMember>,
80    pub(super) local_invocation_index_name: Option<String>,
81}
82
83pub(super) struct EntryPointInterface {
84    /// If `Some`, the input of an entry point is gathered in a special
85    /// struct with members sorted by binding.
86    /// The `EntryPointBinding::members` array is sorted by index,
87    /// so that we can walk it in `write_ep_arguments_initialization`.
88    pub(crate) input: Option<EntryPointBinding>,
89    /// If `Some`, the output of an entry point is flattened.
90    /// The `EntryPointBinding::members` array is sorted by binding,
91    /// So that we can walk it in `Statement::Return` handler.
92    pub(crate) output: Option<EntryPointBinding>,
93    pub(crate) mesh_vertices: Option<EntryPointBinding>,
94    pub(crate) mesh_primitives: Option<EntryPointBinding>,
95    pub(crate) mesh_indices: Option<EntryPointBinding>,
96}
97
98#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
99enum InterfaceKey {
100    Location(u32),
101    BuiltIn(crate::BuiltIn),
102    Other,
103}
104
105impl InterfaceKey {
106    const fn new(binding: Option<&crate::Binding>) -> Self {
107        match binding {
108            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110            None => Self::Other,
111        }
112    }
113}
114
115#[derive(Copy, Clone, PartialEq)]
116pub(super) enum Io {
117    Input,
118    Output,
119    MeshVertices,
120    MeshPrimitives,
121}
122
123/// Argument list for nested entry points
124pub(super) struct NestedEntryPointArgs {
125    /// Arguments literally declared by the user
126    pub user_args: Vec<String>,
127    pub task_payload: Option<String>,
128    pub local_invocation_index: String,
129}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133        return false;
134    };
135    matches!(
136        builtin,
137        crate::BuiltIn::SubgroupSize
138            | crate::BuiltIn::SubgroupInvocationId
139            | crate::BuiltIn::NumSubgroups
140            | crate::BuiltIn::SubgroupId
141    )
142}
143
/// What is needed to generate an access into a `binding_array<sampler>`.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap.
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer.
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer.
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156        Self {
157            out,
158            names: crate::FastHashMap::default(),
159            namer: proc::Namer::default(),
160            options,
161            pipeline_options,
162            entry_point_io: crate::FastHashMap::default(),
163            named_expressions: crate::NamedExpressions::default(),
164            wrapped: super::Wrapped::default(),
165            written_committed_intersection: false,
166            written_candidate_intersection: false,
167            continue_ctx: back::continue_forward::ContinueCtx::default(),
168            temp_access_chain: Vec::new(),
169            need_bake_expressions: Default::default(),
170            function_task_payload_var: Default::default(),
171        }
172    }
173
174    fn reset(&mut self, module: &Module) {
175        self.names.clear();
176        self.namer.reset(
177            module,
178            &super::keywords::RESERVED_SET,
179            proc::KeywordSet::empty(),
180            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181            super::keywords::RESERVED_PREFIXES,
182            &mut self.names,
183        );
184        self.entry_point_io.clear();
185        self.named_expressions.clear();
186        self.wrapped.clear();
187        self.written_committed_intersection = false;
188        self.written_candidate_intersection = false;
189        self.continue_ctx.clear();
190        self.need_bake_expressions.clear();
191        self.function_task_payload_var.clear();
192    }
193
194    /// Generates statements to be inserted immediately before and at the very
195    /// start of the body of each loop, to defeat infinite loop reasoning.
196    /// The 0th item of the returned tuple should be inserted immediately prior
197    /// to the loop and the 1st item should be inserted at the very start of
198    /// the loop body.
199    ///
200    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
201    fn gen_force_bounded_loop_statements(
202        &mut self,
203        level: back::Level,
204    ) -> Option<(String, String)> {
205        if !self.options.force_loop_bounding {
206            return None;
207        }
208
209        let loop_bound_name = self.namer.call("loop_bound");
210        let max = u32::MAX;
211        // Count down from u32::MAX rather than up from 0 to avoid hang on
212        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
213        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214        let level = level.next();
215        let break_and_inc = format!(
216            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218        );
219
220        Some((decl, break_and_inc))
221    }
222
223    /// Helper method used to find which expressions of a given function require baking
224    ///
225    /// # Notes
226    /// Clears `need_bake_expressions` set before adding to it
227    fn update_expressions_to_bake(
228        &mut self,
229        module: &Module,
230        func: &crate::Function,
231        info: &valid::FunctionInfo,
232    ) {
233        use crate::Expression;
234        self.need_bake_expressions.clear();
235        for (exp_handle, expr) in func.expressions.iter() {
236            let expr_info = &info[exp_handle];
237            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238            if min_ref_count <= expr_info.ref_count {
239                self.need_bake_expressions.insert(exp_handle);
240            }
241
242            if let Expression::Math { fun, arg, arg1, .. } = *expr {
243                match fun {
244                    crate::MathFunction::Asinh
245                    | crate::MathFunction::Acosh
246                    | crate::MathFunction::Atanh
247                    | crate::MathFunction::Unpack2x16float
248                    | crate::MathFunction::Unpack2x16snorm
249                    | crate::MathFunction::Unpack2x16unorm
250                    | crate::MathFunction::Unpack4x8snorm
251                    | crate::MathFunction::Unpack4x8unorm
252                    | crate::MathFunction::Unpack4xI8
253                    | crate::MathFunction::Unpack4xU8
254                    | crate::MathFunction::Pack2x16float
255                    | crate::MathFunction::Pack2x16snorm
256                    | crate::MathFunction::Pack2x16unorm
257                    | crate::MathFunction::Pack4x8snorm
258                    | crate::MathFunction::Pack4x8unorm
259                    | crate::MathFunction::Pack4xI8
260                    | crate::MathFunction::Pack4xU8
261                    | crate::MathFunction::Pack4xI8Clamp
262                    | crate::MathFunction::Pack4xU8Clamp => {
263                        self.need_bake_expressions.insert(arg);
264                    }
265                    crate::MathFunction::CountLeadingZeros => {
266                        let inner = info[exp_handle].ty.inner_with(&module.types);
267                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
268                            self.need_bake_expressions.insert(arg);
269                        }
270                    }
271                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
272                        self.need_bake_expressions.insert(arg);
273                        self.need_bake_expressions.insert(arg1.unwrap());
274                    }
275                    _ => {}
276                }
277            }
278
279            if let Expression::Derivative { axis, ctrl, expr } = *expr {
280                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
281                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
282                    self.need_bake_expressions.insert(expr);
283                }
284            }
285
286            if let Expression::GlobalVariable(_) = *expr {
287                let inner = info[exp_handle].ty.inner_with(&module.types);
288
289                if let TypeInner::Sampler { .. } = *inner {
290                    self.need_bake_expressions.insert(exp_handle);
291                }
292            }
293        }
294        for statement in func.body.iter() {
295            match *statement {
296                crate::Statement::SubgroupCollectiveOperation {
297                    op: _,
298                    collective_op: crate::CollectiveOperation::InclusiveScan,
299                    argument,
300                    result: _,
301                } => {
302                    self.need_bake_expressions.insert(argument);
303                }
304                crate::Statement::Atomic {
305                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
306                    ..
307                } => {
308                    self.need_bake_expressions.insert(cmp);
309                }
310                _ => {}
311            }
312        }
313    }
314
    /// Writes the entire `module` as HLSL source to `self.out`.
    ///
    /// Emission order:
    /// 1. The optional special-constants cbuffer.
    /// 2. One cbuffer per dynamic storage buffer offsets target.
    /// 3. `matCx2` typedefs and helper functions.
    /// 4. Every struct that is not dynamically sized.
    /// 5. Special helper functions.
    /// 6. Wrapped helpers for global expressions.
    /// 7. Named constants.
    /// 8. Global variables.
    /// 9. Entry point interface structs.
    /// 10. Regular functions.
    /// 11. The entry points themselves.
    ///
    /// Unless `fake_missing_bindings` is set, a function or entry point is
    /// skipped when it uses a global whose binding resolves neither as a
    /// resource nor as an external texture. For a skipped entry point, the
    /// resolution error is recorded in place of its name in the returned
    /// `ReflectionInfo::entry_point_names`.
    ///
    /// `fragment_entry_point`, when given, is forwarded to `write_ep_interface`.
    ///
    /// # Errors
    /// - `Error::ShaderModelTooLow` if the module uses mesh shaders and the
    ///   target shader model is below 6.5.
    /// - `Error::EntryPointNotFound` if the requested pipeline entry point is
    ///   not present in the module.
    /// - Any error propagated from the writer helpers or from `self.out`.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
            return Err(Error::ShaderModelTooLow(
                "mesh shaders".to_string(),
                ShaderModel::V6_5,
            ));
        }

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // Space 0 is left implicit.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One cbuffer of `bt.size` `uint` offsets per configured bind group.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types. Used below so that a struct
        // returned from an entry point is written with that stage's output
        // semantics.
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                // NOTE(review): `unwrap` assumes structs are non-empty, as
                // validated modules are expected to guarantee — confirm.
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        // Restrict output to the requested pipeline entry point, if any.
        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // Same accessibility check as for regular functions, but the
            // first resolution error is reported instead of just logged.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            // Mesh/task shaders have a wrapper entry point which is declared after the "main"
            // user-written function. We therefore cannot always just document the next function.
            let mut attribute_string = String::new();
            if ep.stage.compute_like() {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    attribute_string,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }
            if let Some(ref info) = ep.mesh_info {
                let topology_str = match info.topology {
                    // NOTE(review): this panics for point topology. Confirm it is
                    // rejected before reaching this writer, or return an error here.
                    crate::MeshOutputTopology::Points => unreachable!(),
                    crate::MeshOutputTopology::Lines => "line",
                    crate::MeshOutputTopology::Triangles => "triangle",
                };
                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;

            // Blank line between entry points. This compares against all of the
            // module's entry points, not just those in `ep_range`.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
569
570    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
571        match *binding {
572            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
573                write!(self.out, "precise ")?;
574            }
575            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
576                write!(self.out, "noperspective ")?;
577            }
578            crate::Binding::Location {
579                interpolation,
580                sampling,
581                ..
582            } => {
583                if let Some(interpolation) = interpolation {
584                    if let Some(string) = interpolation.to_hlsl_str() {
585                        write!(self.out, "{string} ")?
586                    }
587                }
588
589                if let Some(sampling) = sampling {
590                    if let Some(string) = sampling.to_hlsl_str() {
591                        write!(self.out, "{string} ")?
592                    }
593                }
594            }
595            crate::Binding::BuiltIn(_) => {}
596        }
597
598        Ok(())
599    }
600
601    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
602    // if they are struct, so that the `stage` argument here could be omitted.
603    pub(super) fn write_semantic(
604        &mut self,
605        binding: &Option<crate::Binding>,
606        stage: Option<(ShaderStage, Io)>,
607    ) -> BackendResult {
608        let is_per_primitive = match *binding {
609            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
610                if builtin == crate::BuiltIn::ViewIndex
611                    && self.options.shader_model < ShaderModel::V6_1
612                {
613                    return Err(Error::ShaderModelTooLow(
614                        "used @builtin(view_index) or SV_ViewID".to_string(),
615                        ShaderModel::V6_1,
616                    ));
617                }
618                if let Some(builtin_str) = builtin.to_hlsl_str()? {
619                    write!(self.out, " : {builtin_str}")?;
620                }
621                false
622            }
623            Some(crate::Binding::Location {
624                blend_src: Some(1),
625                per_primitive,
626                ..
627            }) => {
628                write!(self.out, " : SV_Target1")?;
629                per_primitive
630            }
631            Some(crate::Binding::Location {
632                location,
633                per_primitive,
634                ..
635            }) => {
636                if stage == Some((ShaderStage::Fragment, Io::Output)) {
637                    write!(self.out, " : SV_Target{location}")?;
638                } else {
639                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
640                }
641                per_primitive
642            }
643            _ => false,
644        };
645        if is_per_primitive {
646            write!(self.out, " : primitive")?;
647        }
648
649        Ok(())
650    }
651
652    pub(super) fn write_interface_struct(
653        &mut self,
654        module: &Module,
655        shader_stage: (ShaderStage, Io),
656        struct_name: String,
657        var_name: Option<&str>,
658        mut members: Vec<EpStructMember>,
659    ) -> Result<EntryPointBinding, Error> {
660        let struct_name = self.namer.call(&struct_name);
661        // Sort the members so that first come the user-defined varyings
662        // in ascending locations, and then built-ins. This allows VS and FS
663        // interfaces to match with regards to order.
664        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
665
666        write!(self.out, "struct {struct_name}")?;
667        writeln!(self.out, " {{")?;
668        let mut local_invocation_index_name = None;
669        let mut subgroup_id_used = false;
670        for m in members.iter() {
671            // Sanity check that each IO member is a built-in or is assigned a
672            // location. Also see note about nesting in `write_ep_input_struct`.
673            debug_assert!(m.binding.is_some());
674
675            match m.binding {
676                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
677                    subgroup_id_used = true;
678                }
679                Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
680                    local_invocation_index_name = Some(m.name.clone());
681                }
682                _ => (),
683            }
684
685            if is_subgroup_builtin_binding(&m.binding) {
686                continue;
687            }
688            write!(self.out, "{}", back::INDENT)?;
689            if let Some(ref binding) = m.binding {
690                self.write_modifier(binding)?;
691            }
692            self.write_type(module, m.ty)?;
693            write!(self.out, " {}", &m.name)?;
694            self.write_semantic(&m.binding, Some(shader_stage))?;
695            writeln!(self.out, ";")?;
696        }
697        if subgroup_id_used && local_invocation_index_name.is_none() {
698            let name = self.namer.call("local_invocation_index");
699            writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
700            local_invocation_index_name = Some(name);
701        }
702        writeln!(self.out, "}};")?;
703        writeln!(self.out)?;
704
705        // See ordering notes on EntryPointInterface fields
706        match shader_stage.1 {
707            Io::Input => {
708                // bring back the original order
709                members.sort_by_key(|m| m.index);
710            }
711            Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
712                // keep it sorted by binding
713            }
714        }
715
716        Ok(EntryPointBinding {
717            arg_name: self
718                .namer
719                .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
720            ty_name: struct_name,
721            members,
722            local_invocation_index_name,
723        })
724    }
725
726    /// Flatten all entry point arguments into a single struct.
727    /// This is needed since we need to re-order them: first placing user locations,
728    /// then built-ins.
729    fn write_ep_input_struct(
730        &mut self,
731        module: &Module,
732        func: &crate::Function,
733        stage: ShaderStage,
734        entry_point_name: &str,
735    ) -> Result<EntryPointBinding, Error> {
736        let struct_name = format!("{stage:?}Input_{entry_point_name}");
737
738        let mut fake_members = Vec::new();
739        for arg in func.arguments.iter() {
740            // NOTE: We don't need to handle nesting structs. All members must
741            // be either built-ins or assigned a location. I.E. `binding` is
742            // `Some`. This is checked in `VaryingContext::validate`. See:
743            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
744            match module.types[arg.ty].inner {
745                TypeInner::Struct { ref members, .. } => {
746                    for member in members.iter() {
747                        let name = self.namer.call_or(&member.name, "member");
748                        let index = fake_members.len() as u32;
749                        fake_members.push(EpStructMember {
750                            name,
751                            ty: member.ty,
752                            binding: member.binding.clone(),
753                            index,
754                        });
755                    }
756                }
757                _ => {
758                    let member_name = self.namer.call_or(&arg.name, "member");
759                    let index = fake_members.len() as u32;
760                    fake_members.push(EpStructMember {
761                        name: member_name,
762                        ty: arg.ty,
763                        binding: arg.binding.clone(),
764                        index,
765                    });
766                }
767            }
768        }
769
770        self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
771    }
772
773    /// Flatten all entry point results into a single struct.
774    /// This is needed since we need to re-order them: first placing user locations,
775    /// then built-ins.
776    fn write_ep_output_struct(
777        &mut self,
778        module: &Module,
779        result: &crate::FunctionResult,
780        stage: ShaderStage,
781        entry_point_name: &str,
782        frag_ep: Option<&FragmentEntryPoint<'_>>,
783    ) -> Result<EntryPointBinding, Error> {
784        let struct_name = format!("{stage:?}Output_{entry_point_name}");
785
786        let empty = [];
787        let members = match module.types[result.ty].inner {
788            TypeInner::Struct { ref members, .. } => members,
789            ref other => {
790                log::error!("Unexpected {other:?} output type without a binding");
791                &empty[..]
792            }
793        };
794
795        // Gather list of fragment input locations. We use this below to remove user-defined
796        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
797        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
798        // writer is supplied with information about the fragment entry point.
799        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
800            let mut fs_input_locs = Vec::new();
801            for arg in frag_ep.func.arguments.iter() {
802                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
803                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
804                    Some(crate::Binding::BuiltIn(_)) | None => {}
805                };
806
807                // NOTE: We don't need to handle struct nesting. See note in
808                // `write_ep_input_struct`.
809                match frag_ep.module.types[arg.ty].inner {
810                    TypeInner::Struct { ref members, .. } => {
811                        for member in members.iter() {
812                            push_if_location(&member.binding);
813                        }
814                    }
815                    _ => push_if_location(&arg.binding),
816                }
817            }
818            fs_input_locs.sort();
819            Some(fs_input_locs)
820        } else {
821            None
822        };
823
824        let mut fake_members = Vec::new();
825        for (index, member) in members.iter().enumerate() {
826            if let Some(ref fs_input_locs) = fs_input_locs {
827                match member.binding {
828                    Some(crate::Binding::Location { location, .. }) => {
829                        if fs_input_locs.binary_search(&location).is_err() {
830                            continue;
831                        }
832                    }
833                    Some(crate::Binding::BuiltIn(_)) | None => {}
834                }
835            }
836
837            let member_name = self.namer.call_or(&member.name, "member");
838            fake_members.push(EpStructMember {
839                name: member_name,
840                ty: member.ty,
841                binding: member.binding.clone(),
842                index: index as u32,
843            });
844        }
845
846        self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
847    }
848
849    /// Writes special interface structures for an entry point. The special structures have
850    /// all the fields flattened into them and sorted by binding. They are needed to emulate
851    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
852    fn write_ep_interface(
853        &mut self,
854        module: &Module,
855        ep: &crate::EntryPoint,
856        ep_name: &str,
857        frag_ep: Option<&FragmentEntryPoint<'_>>,
858    ) -> Result<EntryPointInterface, Error> {
859        let func = &ep.function;
860        let stage = ep.stage;
861        Ok(EntryPointInterface {
862            input: if !func.arguments.is_empty()
863                && (stage == ShaderStage::Fragment
864                    || func
865                        .arguments
866                        .iter()
867                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
868            {
869                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
870            } else {
871                None
872            },
873            output: match func.result {
874                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
875                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
876                }
877                _ => None,
878            },
879            mesh_vertices: if let Some(ref info) = ep.mesh_info {
880                Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
881            } else {
882                None
883            },
884            mesh_primitives: if let Some(ref info) = ep.mesh_info {
885                Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
886            } else {
887                None
888            },
889            mesh_indices: if let Some(ref info) = ep.mesh_info {
890                Some(self.write_ep_mesh_output_indices(info.topology)?)
891            } else {
892                None
893            },
894        })
895    }
896
897    fn write_ep_argument_initialization(
898        &mut self,
899        ep: &crate::EntryPoint,
900        ep_input: &EntryPointBinding,
901        fake_member: &EpStructMember,
902    ) -> BackendResult {
903        match fake_member.binding {
904            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
905                write!(self.out, "WaveGetLaneCount()")?
906            }
907            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
908                write!(self.out, "WaveGetLaneIndex()")?
909            }
910            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
911                self.out,
912                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
913                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
914            )?,
915            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
916                write!(
917                    self.out,
918                    "{}.{} / WaveGetLaneCount()",
919                    ep_input.arg_name,
920                    // When writing SubgroupId, we always guarantee that local_invocation_index_name is written
921                    ep_input.local_invocation_index_name.as_ref().unwrap()
922                )?;
923            }
924            _ => {
925                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
926            }
927        }
928        Ok(())
929    }
930
931    /// Write an entry point preface that initializes the arguments as specified in IR.
932    fn write_ep_arguments_initialization(
933        &mut self,
934        module: &Module,
935        func: &crate::Function,
936        ep_index: u16,
937    ) -> BackendResult {
938        let ep = &module.entry_points[ep_index as usize];
939        let ep_input = match self
940            .entry_point_io
941            .get_mut(&(ep_index as usize))
942            .unwrap()
943            .input
944            .take()
945        {
946            Some(ep_input) => ep_input,
947            None => return Ok(()),
948        };
949        let mut fake_iter = ep_input.members.iter();
950        for (arg_index, arg) in func.arguments.iter().enumerate() {
951            write!(self.out, "{}", back::INDENT)?;
952            self.write_type(module, arg.ty)?;
953            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
954            write!(self.out, " {arg_name}")?;
955            match module.types[arg.ty].inner {
956                TypeInner::Array { base, size, .. } => {
957                    self.write_array_size(module, base, size)?;
958                    write!(self.out, " = ")?;
959                    self.write_ep_argument_initialization(
960                        ep,
961                        &ep_input,
962                        fake_iter.next().unwrap(),
963                    )?;
964                    writeln!(self.out, ";")?;
965                }
966                TypeInner::Struct { ref members, .. } => {
967                    write!(self.out, " = {{ ")?;
968                    for index in 0..members.len() {
969                        if index != 0 {
970                            write!(self.out, ", ")?;
971                        }
972                        self.write_ep_argument_initialization(
973                            ep,
974                            &ep_input,
975                            fake_iter.next().unwrap(),
976                        )?;
977                    }
978                    writeln!(self.out, " }};")?;
979                }
980                _ => {
981                    write!(self.out, " = ")?;
982                    self.write_ep_argument_initialization(
983                        ep,
984                        &ep_input,
985                        fake_iter.next().unwrap(),
986                    )?;
987                    writeln!(self.out, ";")?;
988                }
989            }
990        }
991        assert!(fake_iter.next().is_none());
992        Ok(())
993    }
994
995    /// Helper method used to write global variables
996    /// # Notes
997    /// Always adds a newline
998    fn write_global(
999        &mut self,
1000        module: &Module,
1001        handle: Handle<crate::GlobalVariable>,
1002    ) -> BackendResult {
1003        let global = &module.global_variables[handle];
1004        let inner = &module.types[global.ty].inner;
1005
1006        let handle_ty = match *inner {
1007            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
1008            _ => inner,
1009        };
1010
1011        // External textures are handled entirely differently, so defer entirely to that method.
1012        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
1013        // their bindings separately.
1014        let is_external_texture = matches!(
1015            *handle_ty,
1016            TypeInner::Image {
1017                class: crate::ImageClass::External,
1018                ..
1019            }
1020        );
1021        if is_external_texture {
1022            return self.write_global_external_texture(module, handle, global);
1023        }
1024
1025        if let Some(ref binding) = global.binding {
1026            if let Err(err) = self.options.resolve_resource_binding(binding) {
1027                log::debug!(
1028                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1029                    handle,
1030                    global.name,
1031                    err,
1032                );
1033                return Ok(());
1034            }
1035        }
1036
1037        // Samplers are handled entirely differently, so defer entirely to that method.
1038        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
1039
1040        if is_sampler {
1041            return self.write_global_sampler(module, handle, global);
1042        }
1043
1044        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
1045        let register_ty = match global.space {
1046            crate::AddressSpace::Function => unreachable!("Function address space"),
1047            crate::AddressSpace::Private => {
1048                write!(self.out, "static ")?;
1049                self.write_type(module, global.ty)?;
1050                ""
1051            }
1052            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
1053                write!(self.out, "groupshared ")?;
1054                self.write_type(module, global.ty)?;
1055                ""
1056            }
1057            crate::AddressSpace::Uniform => {
1058                // constant buffer declarations are expected to be inlined, e.g.
1059                // `cbuffer foo: register(b0) { field1: type1; }`
1060                write!(self.out, "cbuffer")?;
1061                "b"
1062            }
1063            crate::AddressSpace::Storage { access } => {
1064                if global
1065                    .memory_decorations
1066                    .contains(crate::MemoryDecorations::COHERENT)
1067                {
1068                    write!(self.out, "globallycoherent ")?;
1069                }
1070                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
1071                    ("RW", "u")
1072                } else {
1073                    ("", "t")
1074                };
1075                write!(self.out, "{prefix}ByteAddressBuffer")?;
1076                register
1077            }
1078            crate::AddressSpace::Handle => {
1079                let register = match *handle_ty {
1080                    // all storage textures are UAV, unconditionally
1081                    TypeInner::Image {
1082                        class: crate::ImageClass::Storage { .. },
1083                        ..
1084                    } => "u",
1085                    _ => "t",
1086                };
1087                self.write_type(module, global.ty)?;
1088                register
1089            }
1090            crate::AddressSpace::Immediate => {
1091                // The type of the immediates will be wrapped in `ConstantBuffer`
1092                write!(self.out, "ConstantBuffer<")?;
1093                "b"
1094            }
1095            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
1096                unimplemented!()
1097            }
1098        };
1099
1100        // If the global is a immediate data write the type now because it will be a
1101        // generic argument to `ConstantBuffer`
1102        if global.space == crate::AddressSpace::Immediate {
1103            self.write_global_type(module, global.ty)?;
1104
1105            // need to write the array size if the type was emitted with `write_type`
1106            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1107                self.write_array_size(module, base, size)?;
1108            }
1109
1110            // Close the angled brackets for the generic argument
1111            write!(self.out, ">")?;
1112        }
1113
1114        let name = &self.names[&NameKey::GlobalVariable(handle)];
1115        write!(self.out, " {name}")?;
1116
1117        // Immediates need to be assigned a binding explicitly by the consumer
1118        // since naga has no way to know the binding from the shader alone
1119        if global.space == crate::AddressSpace::Immediate {
1120            match module.types[global.ty].inner {
1121                TypeInner::Struct { .. } => {}
1122                _ => {
1123                    return Err(Error::Unimplemented(format!(
1124                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1125                    )));
1126                }
1127            }
1128
1129            let target = self
1130                .options
1131                .immediates_target
1132                .as_ref()
1133                .expect("No bind target was defined for the immediates block");
1134            write!(self.out, ": register(b{}", target.register)?;
1135            if target.space != 0 {
1136                write!(self.out, ", space{}", target.space)?;
1137            }
1138            write!(self.out, ")")?;
1139        }
1140
1141        if let Some(ref binding) = global.binding {
1142            // this was already resolved earlier when we started evaluating an entry point.
1143            let bt = self.options.resolve_resource_binding(binding).unwrap();
1144
1145            // need to write the binding array size if the type was emitted with `write_type`
1146            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1147                if let Some(overridden_size) = bt.binding_array_size {
1148                    write!(self.out, "[{overridden_size}]")?;
1149                } else {
1150                    self.write_array_size(module, base, size)?;
1151                }
1152            }
1153
1154            write!(self.out, " : register({}{}", register_ty, bt.register)?;
1155            if bt.space != 0 {
1156                write!(self.out, ", space{}", bt.space)?;
1157            }
1158            write!(self.out, ")")?;
1159        } else {
1160            // need to write the array size if the type was emitted with `write_type`
1161            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1162                self.write_array_size(module, base, size)?;
1163            }
1164            if global.space == crate::AddressSpace::Private {
1165                write!(self.out, " = ")?;
1166                if let Some(init) = global.init {
1167                    self.write_const_expression(module, init, &module.global_expressions)?;
1168                } else {
1169                    self.write_default_init(module, global.ty)?;
1170                }
1171            }
1172        }
1173
1174        if global.space == crate::AddressSpace::Uniform {
1175            write!(self.out, " {{ ")?;
1176
1177            self.write_global_type(module, global.ty)?;
1178
1179            write!(
1180                self.out,
1181                " {}",
1182                &self.names[&NameKey::GlobalVariable(handle)]
1183            )?;
1184
1185            // need to write the array size if the type was emitted with `write_type`
1186            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1187                self.write_array_size(module, base, size)?;
1188            }
1189
1190            writeln!(self.out, "; }}")?;
1191        } else {
1192            writeln!(self.out, ";")?;
1193        }
1194
1195        Ok(())
1196    }
1197
1198    fn write_global_sampler(
1199        &mut self,
1200        module: &Module,
1201        handle: Handle<crate::GlobalVariable>,
1202        global: &crate::GlobalVariable,
1203    ) -> BackendResult {
1204        let binding = *global.binding.as_ref().unwrap();
1205
1206        let key = super::SamplerIndexBufferKey {
1207            group: binding.group,
1208        };
1209        self.write_wrapped_sampler_buffer(key)?;
1210
1211        // This was already validated, so we can confidently unwrap it.
1212        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1213
1214        match module.types[global.ty].inner {
1215            TypeInner::Sampler { comparison } => {
1216                // If we are generating a static access, we create a variable for the sampler.
1217                //
1218                // This prevents the DXIL from containing multiple lookups for the sampler, which
1219                // the backend compiler will then have to eliminate. AMD does seem to be able to
1220                // eliminate these, but better safe than sorry.
1221
1222                write!(self.out, "static const ")?;
1223                self.write_type(module, global.ty)?;
1224
1225                let heap_var = if comparison {
1226                    COMPARISON_SAMPLER_HEAP_VAR
1227                } else {
1228                    SAMPLER_HEAP_VAR
1229                };
1230
1231                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1232                let name = &self.names[&NameKey::GlobalVariable(handle)];
1233                writeln!(
1234                    self.out,
1235                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1236                    register = bt.register
1237                )?;
1238            }
1239            TypeInner::BindingArray { .. } => {
1240                // If we are generating a binding array, we cannot directly access the sampler as the index
1241                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1242                // that represents the "base" index into the sampler index buffer. This constant is added
1243                // to the user provided index to get the final index into the sampler index buffer.
1244
1245                let name = &self.names[&NameKey::GlobalVariable(handle)];
1246                writeln!(
1247                    self.out,
1248                    "static const uint {name} = {register};",
1249                    register = bt.register
1250                )?;
1251            }
1252            _ => unreachable!(),
1253        };
1254
1255        Ok(())
1256    }
1257
1258    /// Write the declarations for an external texture global variable.
1259    /// These are emitted as multiple global variables: Three `Texture2D`s
1260    /// (one for each plane) and a parameters cbuffer.
1261    fn write_global_external_texture(
1262        &mut self,
1263        module: &Module,
1264        handle: Handle<crate::GlobalVariable>,
1265        global: &crate::GlobalVariable,
1266    ) -> BackendResult {
1267        let res_binding = global
1268            .binding
1269            .as_ref()
1270            .expect("External texture global variables must have a resource binding");
1271        let ext_tex_bindings = match self
1272            .options
1273            .resolve_external_texture_resource_binding(res_binding)
1274        {
1275            Ok(bindings) => bindings,
1276            Err(err) => {
1277                log::debug!(
1278                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1279                    handle,
1280                    global.name,
1281                    err,
1282                );
1283                return Ok(());
1284            }
1285        };
1286
1287        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1288            write!(
1289                self.out,
1290                "Texture2D<float4> {}: register(t{}",
1291                name, bt.register
1292            )?;
1293            if bt.space != 0 {
1294                write!(self.out, ", space{}", bt.space)?;
1295            }
1296            writeln!(self.out, ");")?;
1297            Ok(())
1298        };
1299        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1300            let plane_name = &self.names
1301                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1302            write_plane(bt, plane_name)?;
1303        }
1304
1305        let params_name = &self.names
1306            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1307        let params_ty_name =
1308            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1309        write!(
1310            self.out,
1311            "cbuffer {}: register(b{}",
1312            params_name, ext_tex_bindings.params.register
1313        )?;
1314        if ext_tex_bindings.params.space != 0 {
1315            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1316        }
1317        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1318
1319        Ok(())
1320    }
1321
1322    /// Helper method used to write global constants
1323    ///
1324    /// # Notes
1325    /// Ends in a newline
1326    fn write_global_constant(
1327        &mut self,
1328        module: &Module,
1329        handle: Handle<crate::Constant>,
1330    ) -> BackendResult {
1331        write!(self.out, "static const ")?;
1332        let constant = &module.constants[handle];
1333        self.write_type(module, constant.ty)?;
1334        let name = &self.names[&NameKey::Constant(handle)];
1335        write!(self.out, " {name}")?;
1336        // Write size for array type
1337        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1338            self.write_array_size(module, base, size)?;
1339        }
1340        write!(self.out, " = ")?;
1341        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1342        writeln!(self.out, ";")?;
1343        Ok(())
1344    }
1345
1346    pub(super) fn write_array_size(
1347        &mut self,
1348        module: &Module,
1349        base: Handle<crate::Type>,
1350        size: crate::ArraySize,
1351    ) -> BackendResult {
1352        write!(self.out, "[")?;
1353
1354        match size.resolve(module.to_ctx())? {
1355            proc::IndexableLength::Known(size) => {
1356                write!(self.out, "{size}")?;
1357            }
1358            proc::IndexableLength::Dynamic => unreachable!(),
1359        }
1360
1361        write!(self.out, "]")?;
1362
1363        if let TypeInner::Array {
1364            base: next_base,
1365            size: next_size,
1366            ..
1367        } = module.types[base].inner
1368        {
1369            self.write_array_size(module, next_base, next_size)?;
1370        }
1371
1372        Ok(())
1373    }
1374
1375    /// Helper method used to write structs
1376    ///
1377    /// # Notes
1378    /// Ends in a newline
1379    fn write_struct(
1380        &mut self,
1381        module: &Module,
1382        handle: Handle<crate::Type>,
1383        members: &[crate::StructMember],
1384        span: u32,
1385        shader_stage: Option<(ShaderStage, Io)>,
1386    ) -> BackendResult {
1387        // Write struct name
1388        let struct_name = &self.names[&NameKey::Type(handle)];
1389        writeln!(self.out, "struct {struct_name} {{")?;
1390
1391        let mut last_offset = 0;
1392        for (index, member) in members.iter().enumerate() {
1393            if member.binding.is_none() && member.offset > last_offset {
1394                // using int as padding should work as long as the backend
1395                // doesn't support a type that's less than 4 bytes in size
1396                // (Error::UnsupportedScalar catches this)
1397                let padding = (member.offset - last_offset) / 4;
1398                for i in 0..padding {
1399                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1400                }
1401            }
1402            let ty_inner = &module.types[member.ty].inner;
1403            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1404
1405            // The indentation is only for readability
1406            write!(self.out, "{}", back::INDENT)?;
1407
1408            match module.types[member.ty].inner {
1409                TypeInner::Array { base, size, .. } => {
1410                    // HLSL arrays are written as `type name[size]`
1411
1412                    self.write_global_type(module, member.ty)?;
1413
1414                    // Write `name`
1415                    write!(
1416                        self.out,
1417                        " {}",
1418                        &self.names[&NameKey::StructMember(handle, index as u32)]
1419                    )?;
1420                    // Write [size]
1421                    self.write_array_size(module, base, size)?;
1422                }
1423                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1424                // See the module-level block comment in mod.rs for details.
1425                TypeInner::Matrix {
1426                    rows,
1427                    columns,
1428                    scalar,
1429                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1430                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1431                    let field_name_key = NameKey::StructMember(handle, index as u32);
1432
1433                    for i in 0..columns as u8 {
1434                        if i != 0 {
1435                            write!(self.out, "; ")?;
1436                        }
1437                        self.write_value_type(module, &vec_ty)?;
1438                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1439                    }
1440                }
1441                _ => {
1442                    // Write modifier before type
1443                    if let Some(ref binding) = member.binding {
1444                        self.write_modifier(binding)?;
1445                    }
1446
1447                    // Even though Naga IR matrices are column-major, we must describe
1448                    // matrices passed from the CPU as being in row-major order.
1449                    // See the module-level block comment in mod.rs for details.
1450                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1451                        write!(self.out, "row_major ")?;
1452                    }
1453
1454                    // Write the member type and name
1455                    self.write_type(module, member.ty)?;
1456                    write!(
1457                        self.out,
1458                        " {}",
1459                        &self.names[&NameKey::StructMember(handle, index as u32)]
1460                    )?;
1461                }
1462            }
1463
1464            self.write_semantic(&member.binding, shader_stage)?;
1465            writeln!(self.out, ";")?;
1466        }
1467
1468        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1469        if members.last().unwrap().binding.is_none() && span > last_offset {
1470            let padding = (span - last_offset) / 4;
1471            for i in 0..padding {
1472                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1473            }
1474        }
1475
1476        writeln!(self.out, "}};")?;
1477        Ok(())
1478    }
1479
1480    /// Helper method used to write global/structs non image/sampler types
1481    ///
1482    /// # Notes
1483    /// Adds no trailing or leading whitespace
1484    pub(super) fn write_global_type(
1485        &mut self,
1486        module: &Module,
1487        ty: Handle<crate::Type>,
1488    ) -> BackendResult {
1489        let matrix_data = get_inner_matrix_data(module, ty);
1490
1491        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1492        // See the module-level block comment in mod.rs for details.
1493        if let Some(MatrixType {
1494            columns,
1495            rows: crate::VectorSize::Bi,
1496            width: 4,
1497        }) = matrix_data
1498        {
1499            write!(self.out, "__mat{}x2", columns as u8)?;
1500        } else {
1501            // Even though Naga IR matrices are column-major, we must describe
1502            // matrices passed from the CPU as being in row-major order.
1503            // See the module-level block comment in mod.rs for details.
1504            if matrix_data.is_some() {
1505                write!(self.out, "row_major ")?;
1506            }
1507
1508            self.write_type(module, ty)?;
1509        }
1510
1511        Ok(())
1512    }
1513
1514    /// Helper method used to write non image/sampler types
1515    ///
1516    /// # Notes
1517    /// Adds no trailing or leading whitespace
1518    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1519        let inner = &module.types[ty].inner;
1520        match *inner {
1521            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1522            // hlsl array has the size separated from the base type
1523            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1524                self.write_type(module, base)?
1525            }
1526            ref other => self.write_value_type(module, other)?,
1527        }
1528
1529        Ok(())
1530    }
1531
1532    /// Helper method used to write value types
1533    ///
1534    /// # Notes
1535    /// Adds no trailing or leading whitespace
1536    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1537        match *inner {
1538            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1539                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1540            }
1541            TypeInner::Vector { size, scalar } => {
1542                write!(
1543                    self.out,
1544                    "{}{}",
1545                    scalar.to_hlsl_str()?,
1546                    common::vector_size_str(size)
1547                )?;
1548            }
1549            TypeInner::Matrix {
1550                columns,
1551                rows,
1552                scalar,
1553            } => {
1554                // The IR supports only float matrix
1555                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1556
1557                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1558                write!(
1559                    self.out,
1560                    "{}{}x{}",
1561                    scalar.to_hlsl_str()?,
1562                    common::vector_size_str(columns),
1563                    common::vector_size_str(rows),
1564                )?;
1565            }
1566            TypeInner::Image {
1567                dim,
1568                arrayed,
1569                class,
1570            } => {
1571                self.write_image_type(dim, arrayed, class)?;
1572            }
1573            TypeInner::Sampler { comparison } => {
1574                let sampler = if comparison {
1575                    "SamplerComparisonState"
1576                } else {
1577                    "SamplerState"
1578                };
1579                write!(self.out, "{sampler}")?;
1580            }
1581            // HLSL arrays are written as `type name[size]`
1582            // Current code is written arrays only as `[size]`
1583            // Base `type` and `name` should be written outside
1584            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1585                self.write_array_size(module, base, size)?;
1586            }
1587            TypeInner::AccelerationStructure { .. } => {
1588                write!(self.out, "RaytracingAccelerationStructure")?;
1589            }
1590            TypeInner::RayQuery { .. } => {
1591                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1592                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1593            }
1594            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1595        }
1596
1597        Ok(())
1598    }
1599
1600    /// Helper method used to write functions
1601    /// # Notes
1602    /// Ends in a newline
1603    fn write_function(
1604        &mut self,
1605        module: &Module,
1606        name: &str,
1607        func: &crate::Function,
1608        func_ctx: &back::FunctionCtx<'_>,
1609        info: &valid::FunctionInfo,
1610        header: String,
1611    ) -> BackendResult {
1612        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1613
1614        self.update_expressions_to_bake(module, func, info);
1615        let ep = match func_ctx.ty {
1616            back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1617            back::FunctionType::Function(_) => None,
1618        };
1619
1620        let nested = matches!(
1621            ep,
1622            Some(crate::EntryPoint {
1623                stage: ShaderStage::Task | ShaderStage::Mesh,
1624                ..
1625            })
1626        );
1627        if !nested {
1628            write!(self.out, "{header}")?;
1629        }
1630
1631        if let Some(ref result) = func.result {
1632            // Write typedef if return type is an array
1633            let array_return_type = match module.types[result.ty].inner {
1634                TypeInner::Array { base, size, .. } => {
1635                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1636                    write!(self.out, "typedef ")?;
1637                    self.write_type(module, result.ty)?;
1638                    write!(self.out, " {array_return_type}")?;
1639                    self.write_array_size(module, base, size)?;
1640                    writeln!(self.out, ";")?;
1641                    Some(array_return_type)
1642                }
1643                _ => None,
1644            };
1645
1646            // Write modifier
1647            if let Some(
1648                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1649            ) = result.binding
1650            {
1651                self.write_modifier(binding)?;
1652            }
1653
1654            // Write return type
1655            match func_ctx.ty {
1656                back::FunctionType::Function(_) => {
1657                    if let Some(array_return_type) = array_return_type {
1658                        write!(self.out, "{array_return_type}")?;
1659                    } else {
1660                        self.write_type(module, result.ty)?;
1661                    }
1662                }
1663                back::FunctionType::EntryPoint(index) => {
1664                    if let Some(ref ep_output) =
1665                        self.entry_point_io.get(&(index as usize)).unwrap().output
1666                    {
1667                        write!(self.out, "{}", ep_output.ty_name)?;
1668                    } else {
1669                        self.write_type(module, result.ty)?;
1670                    }
1671                }
1672            }
1673        } else {
1674            write!(self.out, "void")?;
1675        }
1676
1677        let nested_name = if nested {
1678            self.namer.call(&format!("_{name}"))
1679        } else {
1680            name.to_string()
1681        };
1682
1683        // Write function name
1684        write!(self.out, " {nested_name}(")?;
1685
1686        let need_workgroup_variables_initialization =
1687            self.need_workgroup_variables_initialization(func_ctx, module);
1688
1689        let mut any_args_written = false;
1690        let mut separator = || {
1691            if any_args_written {
1692                ", "
1693            } else {
1694                any_args_written = true;
1695                ""
1696            }
1697        };
1698
1699        let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1700        let mut local_invocation_index_name = None;
1701        // For nested entry points, collect arg names as we write them so that
1702        // write_nested_function_outer can pass the exact same names to the call site.
1703        let mut nested_wgsl_args: Vec<String> = Vec::new();
1704        let mut nested_task_payload_name: Option<String> = None;
1705        // Write function arguments for non entry point functions
1706        match func_ctx.ty {
1707            back::FunctionType::Function(handle) => {
1708                for (index, arg) in func.arguments.iter().enumerate() {
1709                    write!(self.out, "{}", separator())?;
1710                    self.write_function_argument(module, handle, arg, index)?;
1711                }
1712                // If this reads a task payload variable the variable needs to be passed as an `in` argument
1713                for (var_handle, var) in module.global_variables.iter() {
1714                    let uses = info[var_handle];
1715                    if uses.contains(valid::GlobalUse::READ)
1716                        && !uses.contains(valid::GlobalUse::WRITE)
1717                        && var.space == crate::AddressSpace::TaskPayload
1718                    {
1719                        self.function_task_payload_var.insert(handle, var_handle);
1720                        write!(self.out, "{}in ", separator())?;
1721
1722                        self.write_type(module, var.ty)?;
1723                        let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1724                        write!(self.out, " {name}")?;
1725                        break;
1726                    }
1727                }
1728            }
1729            back::FunctionType::EntryPoint(ep_index) => {
1730                let ep = &module.entry_points[ep_index as usize];
1731                if let Some(ref ep_input) =
1732                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1733                {
1734                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1735                    separator();
1736                    nested_wgsl_args.push(ep_input.arg_name.clone());
1737                } else {
1738                    let stage = ep.stage;
1739                    for (index, arg) in func.arguments.iter().enumerate() {
1740                        write!(self.out, "{}", separator())?;
1741                        self.write_type(module, arg.ty)?;
1742
1743                        let argument_name =
1744                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1745
1746                        if arg.binding
1747                            == Some(crate::Binding::BuiltIn(
1748                                crate::BuiltIn::LocalInvocationIndex,
1749                            ))
1750                        {
1751                            local_invocation_index_name = Some(argument_name.clone());
1752                        }
1753
1754                        nested_wgsl_args.push(argument_name.clone());
1755                        write!(self.out, " {argument_name}")?;
1756                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1757                            self.write_array_size(module, base, size)?;
1758                        }
1759
1760                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1761                    }
1762                }
1763                if ep.stage == ShaderStage::Mesh {
1764                    if let Some(var_handle) = ep.task_payload {
1765                        let var = &module.global_variables[var_handle];
1766                        write!(self.out, "{}in ", separator())?;
1767                        self.write_type(module, var.ty)?;
1768                        let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1769                        write!(self.out, " {arg_name}")?;
1770                        nested_task_payload_name = Some(arg_name.clone());
1771                        if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1772                            self.write_array_size(module, base, size)?;
1773                        }
1774                    }
1775                }
1776                if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1777                    let name = self.namer.call("local_invocation_index");
1778                    write!(self.out, "{}uint {name}", separator())?;
1779                    write!(self.out, " : SV_GroupIndex")?;
1780                    local_invocation_index_name = Some(name);
1781                }
1782            }
1783        }
1784        // Ends of arguments
1785        write!(self.out, ")")?;
1786
1787        // Write semantic if it present
1788        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1789            let stage = module.entry_points[index as usize].stage;
1790            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1791                self.write_semantic(binding, Some((stage, Io::Output)))?;
1792            }
1793        }
1794
1795        // Function body start
1796        writeln!(self.out)?;
1797        writeln!(self.out, "{{")?;
1798
1799        if need_workgroup_variables_initialization && !nested {
1800            let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1801                unreachable!();
1802            };
1803            writeln!(
1804                self.out,
1805                "{}if ({} == 0) {{",
1806                back::INDENT,
1807                // need_workgroup_variables_initialization forces this to be written
1808                // if the user doesn't specify it (so this must be Some())
1809                local_invocation_index_name.as_ref().unwrap(),
1810            )?;
1811            self.write_workgroup_variables_initialization(
1812                func_ctx,
1813                module,
1814                module.entry_points[index as usize].stage,
1815            )?;
1816
1817            writeln!(self.out, "{}}}", back::INDENT)?;
1818            self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1819        }
1820
1821        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1822            self.write_ep_arguments_initialization(module, func, index)?;
1823        }
1824
1825        // Write function local variables
1826        for (handle, local) in func.local_variables.iter() {
1827            // Write indentation (only for readability)
1828            write!(self.out, "{}", back::INDENT)?;
1829
1830            // Write the local name
1831            // The leading space is important
1832            self.write_type(module, local.ty)?;
1833            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1834            // Write size for array type
1835            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1836                self.write_array_size(module, base, size)?;
1837            }
1838
1839            let is_ray_query = match module.types[local.ty].inner {
1840                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1841                TypeInner::RayQuery { .. } => true,
1842                _ => {
1843                    write!(self.out, " = ")?;
1844                    // Write the local initializer if needed
1845                    if let Some(init) = local.init {
1846                        self.write_expr(module, init, func_ctx)?;
1847                    } else {
1848                        // Zero initialize local variables
1849                        self.write_default_init(module, local.ty)?;
1850                    }
1851                    false
1852                }
1853            };
1854            // Finish the local with `;` and add a newline (only for readability)
1855            writeln!(self.out, ";")?;
1856            // If it's a ray query, we also want a tracker variable
1857            if is_ray_query {
1858                write!(self.out, "{}", back::INDENT)?;
1859                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1860                writeln!(
1861                    self.out,
1862                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1863                    self.names[&func_ctx.name_key(handle)]
1864                )?;
1865            }
1866        }
1867
1868        if !func.local_variables.is_empty() {
1869            writeln!(self.out)?;
1870        }
1871
1872        // Write the function body (statement list)
1873        for sta in func.body.iter() {
1874            // The indentation should always be 1 when writing the function body
1875            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1876        }
1877
1878        writeln!(self.out, "}}")?;
1879
1880        if nested {
1881            self.write_nested_function_outer(
1882                module,
1883                func_ctx,
1884                &header,
1885                name,
1886                need_workgroup_variables_initialization,
1887                &nested_name,
1888                ep.unwrap(),
1889                NestedEntryPointArgs {
1890                    user_args: nested_wgsl_args,
1891                    task_payload: nested_task_payload_name,
1892                    // guaranteed to be set for nested functions (task/mesh shaders)
1893                    local_invocation_index: local_invocation_index_name.unwrap(),
1894                },
1895            )?;
1896        }
1897
1898        self.named_expressions.clear();
1899
1900        Ok(())
1901    }
1902
1903    fn write_function_argument(
1904        &mut self,
1905        module: &Module,
1906        handle: Handle<crate::Function>,
1907        arg: &crate::FunctionArgument,
1908        index: usize,
1909    ) -> BackendResult {
1910        // External texture arguments must be expanded into separate
1911        // arguments for each plane and the params buffer.
1912        if let TypeInner::Image {
1913            class: crate::ImageClass::External,
1914            ..
1915        } = module.types[arg.ty].inner
1916        {
1917            return self.write_function_external_texture_argument(module, handle, index);
1918        }
1919
1920        // Write argument type
1921        let arg_ty = match module.types[arg.ty].inner {
1922            // pointers in function arguments are expected and resolve to `inout`
1923            TypeInner::Pointer { base, .. } => {
1924                //TODO: can we narrow this down to just `in` when possible?
1925                write!(self.out, "inout ")?;
1926                base
1927            }
1928            _ => arg.ty,
1929        };
1930        self.write_type(module, arg_ty)?;
1931
1932        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1933
1934        // Write argument name. Space is important.
1935        write!(self.out, " {argument_name}")?;
1936        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1937            self.write_array_size(module, base, size)?;
1938        }
1939
1940        Ok(())
1941    }
1942
1943    fn write_function_external_texture_argument(
1944        &mut self,
1945        module: &Module,
1946        handle: Handle<crate::Function>,
1947        index: usize,
1948    ) -> BackendResult {
1949        let plane_names = [0, 1, 2].map(|i| {
1950            &self.names[&NameKey::ExternalTextureFunctionArgument(
1951                handle,
1952                index as u32,
1953                ExternalTextureNameKey::Plane(i),
1954            )]
1955        });
1956        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1957            handle,
1958            index as u32,
1959            ExternalTextureNameKey::Params,
1960        )];
1961        let params_ty_name =
1962            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1963        write!(
1964            self.out,
1965            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1966            plane_names[0], plane_names[1], plane_names[2],
1967        )?;
1968        Ok(())
1969    }
1970
1971    fn need_workgroup_variables_initialization(
1972        &mut self,
1973        func_ctx: &back::FunctionCtx,
1974        module: &Module,
1975    ) -> bool {
1976        self.options.zero_initialize_workgroup_memory
1977            && func_ctx.ty.is_compute_like_entry_point(module)
1978            && module.global_variables.iter().any(|(handle, var)| {
1979                !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
1980            })
1981    }
1982
1983    pub(super) fn write_workgroup_variables_initialization(
1984        &mut self,
1985        func_ctx: &back::FunctionCtx,
1986        module: &Module,
1987        stage: ShaderStage,
1988    ) -> BackendResult {
1989        let vars = module.global_variables.iter().filter(|&(handle, var)| {
1990            // Read-only in mesh shaders
1991            let task_needs_zero =
1992                (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
1993            !func_ctx.info[handle].is_empty()
1994                && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
1995        });
1996
1997        for (handle, var) in vars {
1998            let name = &self.names[&NameKey::GlobalVariable(handle)];
1999            write!(self.out, "{}{} = ", back::Level(2), name)?;
2000            self.write_default_init(module, var.ty)?;
2001            writeln!(self.out, ";")?;
2002        }
2003        Ok(())
2004    }
2005
    /// Helper method used to write switches
    ///
    /// Emits a switch on `selector` over `cases` at indentation `level`.
    ///
    /// If every case except the last is an empty fall-through, no `switch` is
    /// written at all. Instead, the last case's body is wrapped in a
    /// `do { ... } while(false);` because FXC mishandles such switches.
    /// Otherwise a real `switch` is written. Fall-through out of a non-empty
    /// case body is emulated by duplicating the bodies of the cases it would
    /// fall into.
    ///
    /// If `continue_ctx` requires it, a `bool` flag is declared before the
    /// switch and tested afterwards to re-issue a forwarded `continue`/`break`
    /// (see `back::continue_forward`).
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        // Indentation for case labels (level 1) and case bodies (level 2).
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // See docs of `back::continue_forward` module.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // Check if there is only one body, by seeing if all except the last case are fall through
        // with empty bodies. FXC doesn't handle these switches correctly, so
        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
        // is no need to check if one of the cases would have matched.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Start the do-while
            writeln!(self.out, "{level}do {{")?;
            // Note: Expressions have no side-effects so we don't need to emit selector expression.

            // Body
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            // End do-while
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            // Start the switch
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // The new block is not only stylistic, it plays a role here:
                // We might end up having to write the same case body
                // multiple times due to FXC not supporting fallthrough.
                // Therefore, some `Expression`s written by `Statement::Emit`
                // will end up having the same name (`_expr<handle_index>`).
                // So we need to put each case in its own scope.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                // Although FXC does support a series of case clauses before
                // a block[^yes], it does not support fallthrough from a
                // non-empty case block to the next[^no]. If this case has a
                // non-empty body with a fallthrough, emulate that by
                // duplicating the bodies of all the cases it would fall
                // into as extensions of this case's own body. This makes
                // the HLSL output potentially quadratic in the size of the
                // Naga IR.
                //
                // [^yes]: ```hlsl
                // case 1:
                // case 2: do_stuff()
                // ```
                // [^no]: ```hlsl
                // case 1: do_this();
                // case 2: do_that();
                // ```
                if case.fall_through && !case.body.is_empty() {
                    // `end_case_idx` is the first case after `i` that does not
                    // fall through. The unwrap assumes such a case always
                    // exists (presumably guaranteed by validation; confirm).
                    let curr_len = i + 1;
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Clear all named expressions that were previously inserted by the statements in the block
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    // Only add a `break` if the final copied body does not
                    // already end in a terminator statement.
                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Handle any forwarded continue statements.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
2150
2151    fn write_index(
2152        &mut self,
2153        module: &Module,
2154        index: Index,
2155        func_ctx: &back::FunctionCtx<'_>,
2156    ) -> BackendResult {
2157        match index {
2158            Index::Static(index) => {
2159                write!(self.out, "{index}")?;
2160            }
2161            Index::Expression(index) => {
2162                self.write_expr(module, index, func_ctx)?;
2163            }
2164        }
2165        Ok(())
2166    }
2167
2168    /// Helper method used to write statements
2169    ///
2170    /// # Notes
2171    /// Always adds a newline
2172    fn write_stmt(
2173        &mut self,
2174        module: &Module,
2175        stmt: &crate::Statement,
2176        func_ctx: &back::FunctionCtx<'_>,
2177        level: back::Level,
2178    ) -> BackendResult {
2179        use crate::Statement;
2180
2181        match *stmt {
2182            Statement::Emit(ref range) => {
2183                for handle in range.clone() {
2184                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2185                    let expr_name = if ptr_class.is_some() {
2186                        // HLSL can't save a pointer-valued expression in a variable,
2187                        // but we shouldn't ever need to: they should never be named expressions,
2188                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
2189                        None
2190                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2191                        // Front end provides names for all variables at the start of writing.
2192                        // But we write them to step by step. We need to recache them
2193                        // Otherwise, we could accidentally write variable name instead of full expression.
2194                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
2195                        Some(self.namer.call(name))
2196                    } else if self.need_bake_expressions.contains(&handle) {
2197                        Some(Baked(handle).to_string())
2198                    } else {
2199                        None
2200                    };
2201
2202                    if let Some(name) = expr_name {
2203                        write!(self.out, "{level}")?;
2204                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2205                    }
2206                }
2207            }
2208            // TODO: copy-paste from glsl-out
2209            Statement::Block(ref block) => {
2210                write!(self.out, "{level}")?;
2211                writeln!(self.out, "{{")?;
2212                for sta in block.iter() {
2213                    // Increase the indentation to help with readability
2214                    self.write_stmt(module, sta, func_ctx, level.next())?
2215                }
2216                writeln!(self.out, "{level}}}")?
2217            }
2218            // TODO: copy-paste from glsl-out
2219            Statement::If {
2220                condition,
2221                ref accept,
2222                ref reject,
2223            } => {
2224                write!(self.out, "{level}")?;
2225                write!(self.out, "if (")?;
2226                self.write_expr(module, condition, func_ctx)?;
2227                writeln!(self.out, ") {{")?;
2228
2229                let l2 = level.next();
2230                for sta in accept {
2231                    // Increase indentation to help with readability
2232                    self.write_stmt(module, sta, func_ctx, l2)?;
2233                }
2234
2235                // If there are no statements in the reject block we skip writing it
2236                // This is only for readability
2237                if !reject.is_empty() {
2238                    writeln!(self.out, "{level}}} else {{")?;
2239
2240                    for sta in reject {
2241                        // Increase indentation to help with readability
2242                        self.write_stmt(module, sta, func_ctx, l2)?;
2243                    }
2244                }
2245
2246                writeln!(self.out, "{level}}}")?
2247            }
2248            // TODO: copy-paste from glsl-out
2249            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2250            Statement::Return { value: None } => {
2251                writeln!(self.out, "{level}return;")?;
2252            }
2253            Statement::Return { value: Some(expr) } => {
2254                let base_ty_res = &func_ctx.info[expr].ty;
2255                let mut resolved = base_ty_res.inner_with(&module.types);
2256                if let TypeInner::Pointer { base, space: _ } = *resolved {
2257                    resolved = &module.types[base].inner;
2258                }
2259
2260                if let TypeInner::Struct { .. } = *resolved {
2261                    // We can safely unwrap here, since we now we working with struct
2262                    let ty = base_ty_res.handle().unwrap();
2263                    let struct_name = &self.names[&NameKey::Type(ty)];
2264                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2265                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2266                    self.write_expr(module, expr, func_ctx)?;
2267                    writeln!(self.out, ";")?;
2268
2269                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2270                    let ep_output = match func_ctx.ty {
2271                        back::FunctionType::Function(_) => None,
2272                        back::FunctionType::EntryPoint(index) => self
2273                            .entry_point_io
2274                            .get(&(index as usize))
2275                            .unwrap()
2276                            .output
2277                            .as_ref(),
2278                    };
2279                    let final_name = match ep_output {
2280                        Some(ep_output) => {
2281                            let final_name = self.namer.call(&variable_name);
2282                            write!(
2283                                self.out,
2284                                "{}const {} {} = {{ ",
2285                                level, ep_output.ty_name, final_name,
2286                            )?;
2287                            for (index, m) in ep_output.members.iter().enumerate() {
2288                                if index != 0 {
2289                                    write!(self.out, ", ")?;
2290                                }
2291                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2292                                write!(self.out, "{variable_name}.{member_name}")?;
2293                            }
2294                            writeln!(self.out, " }};")?;
2295                            final_name
2296                        }
2297                        None => variable_name,
2298                    };
2299                    writeln!(self.out, "{level}return {final_name};")?;
2300                } else {
2301                    write!(self.out, "{level}return ")?;
2302                    self.write_expr(module, expr, func_ctx)?;
2303                    writeln!(self.out, ";")?
2304                }
2305            }
2306            Statement::Store { pointer, value } => {
2307                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2308                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2309                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2310                    self.write_storage_store(
2311                        module,
2312                        var_handle,
2313                        StoreValue::Expression(value),
2314                        func_ctx,
2315                        level,
2316                        None,
2317                    )?;
2318                } else {
2319                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2320                    // See the module-level block comment in mod.rs for details.
2321                    //
2322                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2323                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2324                    enum MatrixAccess {
2325                        Direct {
2326                            base: Handle<crate::Expression>,
2327                            index: u32,
2328                        },
2329                        Struct {
2330                            columns: crate::VectorSize,
2331                            base: Handle<crate::Expression>,
2332                        },
2333                    }
2334
2335                    let get_members = |expr: Handle<crate::Expression>| {
2336                        let resolved = func_ctx.resolve_type(expr, &module.types);
2337                        match *resolved {
2338                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2339                                TypeInner::Struct { ref members, .. } => Some(members),
2340                                _ => None,
2341                            },
2342                            _ => None,
2343                        }
2344                    };
2345
2346                    write!(self.out, "{level}")?;
2347
2348                    let matrix_access_on_lhs =
2349                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2350                            |(matrix_expr, vector, scalar)| match (
2351                                func_ctx.resolve_type(matrix_expr, &module.types),
2352                                &func_ctx.expressions[matrix_expr],
2353                            ) {
2354                                (
2355                                    &TypeInner::Pointer { base: ty, .. },
2356                                    &crate::Expression::AccessIndex { base, index },
2357                                ) if matches!(
2358                                    module.types[ty].inner,
2359                                    TypeInner::Matrix {
2360                                        rows: crate::VectorSize::Bi,
2361                                        ..
2362                                    }
2363                                ) && get_members(base)
2364                                    .map(|members| members[index as usize].binding.is_none())
2365                                    == Some(true) =>
2366                                {
2367                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2368                                }
2369                                _ => {
2370                                    if let Some(MatrixType {
2371                                        columns,
2372                                        rows: crate::VectorSize::Bi,
2373                                        width: 4,
2374                                    }) = get_inner_matrix_of_struct_array_member(
2375                                        module,
2376                                        matrix_expr,
2377                                        func_ctx,
2378                                        true,
2379                                    ) {
2380                                        Some((
2381                                            MatrixAccess::Struct {
2382                                                columns,
2383                                                base: matrix_expr,
2384                                            },
2385                                            vector,
2386                                            scalar,
2387                                        ))
2388                                    } else {
2389                                        None
2390                                    }
2391                                }
2392                            },
2393                        );
2394
2395                    match matrix_access_on_lhs {
2396                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2397                            let base_ty_res = &func_ctx.info[base].ty;
2398                            let resolved = base_ty_res.inner_with(&module.types);
2399                            let ty = match *resolved {
2400                                TypeInner::Pointer { base, .. } => base,
2401                                _ => base_ty_res.handle().unwrap(),
2402                            };
2403
2404                            if let Some(Index::Static(vec_index)) = vector {
2405                                self.write_expr(module, base, func_ctx)?;
2406                                write!(
2407                                    self.out,
2408                                    ".{}_{}",
2409                                    &self.names[&NameKey::StructMember(ty, index)],
2410                                    vec_index
2411                                )?;
2412
2413                                if let Some(scalar_index) = scalar {
2414                                    write!(self.out, "[")?;
2415                                    self.write_index(module, scalar_index, func_ctx)?;
2416                                    write!(self.out, "]")?;
2417                                }
2418
2419                                write!(self.out, " = ")?;
2420                                self.write_expr(module, value, func_ctx)?;
2421                                writeln!(self.out, ";")?;
2422                            } else {
2423                                let access = WrappedStructMatrixAccess { ty, index };
2424                                match (&vector, &scalar) {
2425                                    (&Some(_), &Some(_)) => {
2426                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2427                                            access,
2428                                        )?;
2429                                    }
2430                                    (&Some(_), &None) => {
2431                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2432                                            access,
2433                                        )?;
2434                                    }
2435                                    (&None, _) => {
2436                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2437                                    }
2438                                }
2439
2440                                write!(self.out, "(")?;
2441                                self.write_expr(module, base, func_ctx)?;
2442                                write!(self.out, ", ")?;
2443                                self.write_expr(module, value, func_ctx)?;
2444
2445                                if let Some(Index::Expression(vec_index)) = vector {
2446                                    write!(self.out, ", ")?;
2447                                    self.write_expr(module, vec_index, func_ctx)?;
2448
2449                                    if let Some(scalar_index) = scalar {
2450                                        write!(self.out, ", ")?;
2451                                        self.write_index(module, scalar_index, func_ctx)?;
2452                                    }
2453                                }
2454                                writeln!(self.out, ");")?;
2455                            }
2456                        }
2457                        Some((
2458                            MatrixAccess::Struct { columns, base },
2459                            Some(Index::Expression(vec_index)),
2460                            scalar,
2461                        )) => {
2462                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2463                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2464
2465                            if scalar.is_some() {
2466                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2467                            } else {
2468                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2469                            }
2470                            write!(self.out, "(")?;
2471                            self.write_expr(module, base, func_ctx)?;
2472                            write!(self.out, ", ")?;
2473                            self.write_expr(module, vec_index, func_ctx)?;
2474
2475                            if let Some(scalar_index) = scalar {
2476                                write!(self.out, ", ")?;
2477                                self.write_index(module, scalar_index, func_ctx)?;
2478                            }
2479
2480                            write!(self.out, ", ")?;
2481                            self.write_expr(module, value, func_ctx)?;
2482
2483                            writeln!(self.out, ");")?;
2484                        }
2485                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2486                        | Some((MatrixAccess::Struct { .. }, None, _))
2487                        | None => {
2488                            self.write_expr(module, pointer, func_ctx)?;
2489                            write!(self.out, " = ")?;
2490
2491                            // We cast the RHS of this store in cases where the LHS
2492                            // is a struct member with type:
2493                            //  - matCx2 or
2494                            //  - a (possibly nested) array of matCx2's
2495                            if let Some(MatrixType {
2496                                columns,
2497                                rows: crate::VectorSize::Bi,
2498                                width: 4,
2499                            }) = get_inner_matrix_of_struct_array_member(
2500                                module, pointer, func_ctx, false,
2501                            ) {
2502                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2503                                if let TypeInner::Pointer { base, .. } = *resolved {
2504                                    resolved = &module.types[base].inner;
2505                                }
2506
2507                                write!(self.out, "(__mat{}x2", columns as u8)?;
2508                                if let TypeInner::Array { base, size, .. } = *resolved {
2509                                    self.write_array_size(module, base, size)?;
2510                                }
2511                                write!(self.out, ")")?;
2512                            }
2513
2514                            self.write_expr(module, value, func_ctx)?;
2515                            writeln!(self.out, ";")?
2516                        }
2517                    }
2518                }
2519            }
2520            Statement::Loop {
2521                ref body,
2522                ref continuing,
2523                break_if,
2524            } => {
2525                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2526                let gate_name = (!continuing.is_empty() || break_if.is_some())
2527                    .then(|| self.namer.call("loop_init"));
2528
2529                if let Some((ref decl, _)) = force_loop_bound_statements {
2530                    writeln!(self.out, "{decl}")?;
2531                }
2532                if let Some(ref gate_name) = gate_name {
2533                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2534                }
2535
2536                self.continue_ctx.enter_loop();
2537                writeln!(self.out, "{level}while(true) {{")?;
2538                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2539                    writeln!(self.out, "{break_and_inc}")?;
2540                }
2541                let l2 = level.next();
2542                if let Some(gate_name) = gate_name {
2543                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2544                    let l3 = l2.next();
2545                    for sta in continuing.iter() {
2546                        self.write_stmt(module, sta, func_ctx, l3)?;
2547                    }
2548                    if let Some(condition) = break_if {
2549                        write!(self.out, "{l3}if (")?;
2550                        self.write_expr(module, condition, func_ctx)?;
2551                        writeln!(self.out, ") {{")?;
2552                        writeln!(self.out, "{}break;", l3.next())?;
2553                        writeln!(self.out, "{l3}}}")?;
2554                    }
2555                    writeln!(self.out, "{l2}}}")?;
2556                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2557                }
2558
2559                for sta in body.iter() {
2560                    self.write_stmt(module, sta, func_ctx, l2)?;
2561                }
2562
2563                writeln!(self.out, "{level}}}")?;
2564                self.continue_ctx.exit_loop();
2565            }
2566            Statement::Break => writeln!(self.out, "{level}break;")?,
2567            Statement::Continue => {
2568                if let Some(variable) = self.continue_ctx.continue_encountered() {
2569                    writeln!(self.out, "{level}{variable} = true;")?;
2570                    writeln!(self.out, "{level}break;")?
2571                } else {
2572                    writeln!(self.out, "{level}continue;")?
2573                }
2574            }
2575            Statement::ControlBarrier(barrier) => {
2576                self.write_control_barrier(barrier, level)?;
2577            }
2578            Statement::MemoryBarrier(barrier) => {
2579                self.write_memory_barrier(barrier, level)?;
2580            }
2581            Statement::ImageStore {
2582                image,
2583                coordinate,
2584                array_index,
2585                value,
2586            } => {
2587                write!(self.out, "{level}")?;
2588                self.write_expr(module, image, func_ctx)?;
2589
2590                write!(self.out, "[")?;
2591                if let Some(index) = array_index {
2592                    // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
2593                    write!(self.out, "int3(")?;
2594                    self.write_expr(module, coordinate, func_ctx)?;
2595                    write!(self.out, ", ")?;
2596                    self.write_expr(module, index, func_ctx)?;
2597                    write!(self.out, ")")?;
2598                } else {
2599                    self.write_expr(module, coordinate, func_ctx)?;
2600                }
2601                write!(self.out, "]")?;
2602
2603                write!(self.out, " = ")?;
2604                self.write_expr(module, value, func_ctx)?;
2605                writeln!(self.out, ";")?;
2606            }
2607            Statement::Call {
2608                function,
2609                ref arguments,
2610                result,
2611            } => {
2612                write!(self.out, "{level}")?;
2613
2614                if let Some(expr) = result {
2615                    write!(self.out, "const ")?;
2616                    let name = Baked(expr).to_string();
2617                    let expr_ty = &func_ctx.info[expr].ty;
2618                    let ty_inner = match *expr_ty {
2619                        proc::TypeResolution::Handle(handle) => {
2620                            self.write_type(module, handle)?;
2621                            &module.types[handle].inner
2622                        }
2623                        proc::TypeResolution::Value(ref value) => {
2624                            self.write_value_type(module, value)?;
2625                            value
2626                        }
2627                    };
2628                    write!(self.out, " {name}")?;
2629                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2630                        self.write_array_size(module, base, size)?;
2631                    }
2632                    write!(self.out, " = ")?;
2633                    self.named_expressions.insert(expr, name);
2634                }
2635                let func_name = &self.names[&NameKey::Function(function)];
2636                write!(self.out, "{func_name}(")?;
2637                let mut any_args_written = false;
2638                let mut separator = || {
2639                    if any_args_written {
2640                        ", "
2641                    } else {
2642                        any_args_written = true;
2643                        ""
2644                    }
2645                };
2646                for argument in arguments {
2647                    write!(self.out, "{}", separator())?;
2648                    self.write_expr(module, *argument, func_ctx)?;
2649                }
2650                if let Some(&var) = self.function_task_payload_var.get(&function) {
2651                    let name = &self.names[&NameKey::GlobalVariable(var)];
2652                    // Pass it through directly, whether its an in variable to this function or the global variable
2653                    write!(self.out, "{}{name}", separator())?;
2654                }
2655                writeln!(self.out, ");")?;
2656            }
2657            Statement::Atomic {
2658                pointer,
2659                ref fun,
2660                value,
2661                result,
2662            } => {
2663                write!(self.out, "{level}")?;
2664                let res_var_info = if let Some(res_handle) = result {
2665                    let name = Baked(res_handle).to_string();
2666                    match func_ctx.info[res_handle].ty {
2667                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2668                        proc::TypeResolution::Value(ref value) => {
2669                            self.write_value_type(module, value)?
2670                        }
2671                    };
2672                    write!(self.out, " {name}; ")?;
2673                    self.named_expressions.insert(res_handle, name.clone());
2674                    Some((res_handle, name))
2675                } else {
2676                    None
2677                };
2678                let pointer_space = func_ctx
2679                    .resolve_type(pointer, &module.types)
2680                    .pointer_space()
2681                    .unwrap();
2682                let fun_str = fun.to_hlsl_suffix();
2683                let compare_expr = match *fun {
2684                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2685                    _ => None,
2686                };
2687                match pointer_space {
2688                    crate::AddressSpace::WorkGroup => {
2689                        write!(self.out, "Interlocked{fun_str}(")?;
2690                        self.write_expr(module, pointer, func_ctx)?;
2691                        self.emit_hlsl_atomic_tail(
2692                            module,
2693                            func_ctx,
2694                            fun,
2695                            compare_expr,
2696                            value,
2697                            &res_var_info,
2698                        )?;
2699                    }
2700                    crate::AddressSpace::Storage { .. } => {
2701                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2702                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2703                        let width = match func_ctx.resolve_type(value, &module.types) {
2704                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2705                            _ => "",
2706                        };
2707                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2708                        let chain = mem::take(&mut self.temp_access_chain);
2709                        self.write_storage_address(module, &chain, func_ctx)?;
2710                        self.temp_access_chain = chain;
2711                        self.emit_hlsl_atomic_tail(
2712                            module,
2713                            func_ctx,
2714                            fun,
2715                            compare_expr,
2716                            value,
2717                            &res_var_info,
2718                        )?;
2719                    }
2720                    ref other => {
2721                        return Err(Error::Custom(format!(
2722                            "invalid address space {other:?} for atomic statement"
2723                        )))
2724                    }
2725                }
2726                if let Some(cmp) = compare_expr {
2727                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2728                        write!(
2729                            self.out,
2730                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2731                        )?;
2732                        self.write_expr(module, cmp, func_ctx)?;
2733                        writeln!(self.out, ");")?;
2734                    }
2735                }
2736            }
2737            Statement::ImageAtomic {
2738                image,
2739                coordinate,
2740                array_index,
2741                fun,
2742                value,
2743            } => {
2744                write!(self.out, "{level}")?;
2745
2746                let fun_str = fun.to_hlsl_suffix();
2747                write!(self.out, "Interlocked{fun_str}(")?;
2748                self.write_expr(module, image, func_ctx)?;
2749                write!(self.out, "[")?;
2750                self.write_texture_coordinates(
2751                    "int",
2752                    coordinate,
2753                    array_index,
2754                    None,
2755                    module,
2756                    func_ctx,
2757                )?;
2758                write!(self.out, "],")?;
2759
2760                self.write_expr(module, value, func_ctx)?;
2761                writeln!(self.out, ");")?;
2762            }
2763            Statement::WorkGroupUniformLoad { pointer, result } => {
2764                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2765                write!(self.out, "{level}")?;
2766                let name = Baked(result).to_string();
2767                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2768
2769                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2770            }
2771            Statement::Switch {
2772                selector,
2773                ref cases,
2774            } => {
2775                self.write_switch(module, func_ctx, level, selector, cases)?;
2776            }
2777            Statement::RayQuery { query, ref fun } => {
2778                // There are three possibilities for a ptr to be:
2779                // 1. A variable
2780                // 2. A function argument
2781                // 3. part of a struct
2782                //
2783                // 2 and 3 are not possible, a ray query (in naga IR)
2784                // is not allowed to be passed into a function, and
2785                // all languages disallow it in a struct (you get fun results if
2786                // you try it :) ).
2787                //
2788                // Therefore, the ray query expression must be a variable.
2789                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2790                else {
2791                    unreachable!()
2792                };
2793
2794                let tracker_expr_name = format!(
2795                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2796                    self.names[&func_ctx.name_key(query_var)]
2797                );
2798
2799                match *fun {
2800                    RayQueryFunction::Initialize {
2801                        acceleration_structure,
2802                        descriptor,
2803                    } => {
2804                        self.write_initialize_function(
2805                            module,
2806                            level,
2807                            query,
2808                            acceleration_structure,
2809                            descriptor,
2810                            &tracker_expr_name,
2811                            func_ctx,
2812                        )?;
2813                    }
2814                    RayQueryFunction::Proceed { result } => {
2815                        self.write_proceed(
2816                            module,
2817                            level,
2818                            query,
2819                            result,
2820                            &tracker_expr_name,
2821                            func_ctx,
2822                        )?;
2823                    }
2824                    RayQueryFunction::GenerateIntersection { hit_t } => {
2825                        self.write_generate_intersection(
2826                            module,
2827                            level,
2828                            query,
2829                            hit_t,
2830                            &tracker_expr_name,
2831                            func_ctx,
2832                        )?;
2833                    }
2834                    RayQueryFunction::ConfirmIntersection => {
2835                        self.write_confirm_intersection(
2836                            module,
2837                            level,
2838                            query,
2839                            &tracker_expr_name,
2840                            func_ctx,
2841                        )?;
2842                    }
2843                    RayQueryFunction::Terminate => {
2844                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2845                    }
2846                }
2847            }
2848            Statement::SubgroupBallot { result, predicate } => {
2849                write!(self.out, "{level}")?;
2850                let name = Baked(result).to_string();
2851                write!(self.out, "const uint4 {name} = ")?;
2852                self.named_expressions.insert(result, name);
2853
2854                write!(self.out, "WaveActiveBallot(")?;
2855                match predicate {
2856                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2857                    None => write!(self.out, "true")?,
2858                }
2859                writeln!(self.out, ");")?;
2860            }
2861            Statement::SubgroupCollectiveOperation {
2862                op,
2863                collective_op,
2864                argument,
2865                result,
2866            } => {
2867                write!(self.out, "{level}")?;
2868                write!(self.out, "const ")?;
2869                let name = Baked(result).to_string();
2870                match func_ctx.info[result].ty {
2871                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2872                    proc::TypeResolution::Value(ref value) => {
2873                        self.write_value_type(module, value)?
2874                    }
2875                };
2876                write!(self.out, " {name} = ")?;
2877                self.named_expressions.insert(result, name);
2878
2879                match (collective_op, op) {
2880                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2881                        write!(self.out, "WaveActiveAllTrue(")?
2882                    }
2883                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2884                        write!(self.out, "WaveActiveAnyTrue(")?
2885                    }
2886                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2887                        write!(self.out, "WaveActiveSum(")?
2888                    }
2889                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2890                        write!(self.out, "WaveActiveProduct(")?
2891                    }
2892                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2893                        write!(self.out, "WaveActiveMax(")?
2894                    }
2895                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2896                        write!(self.out, "WaveActiveMin(")?
2897                    }
2898                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2899                        write!(self.out, "WaveActiveBitAnd(")?
2900                    }
2901                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2902                        write!(self.out, "WaveActiveBitOr(")?
2903                    }
2904                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2905                        write!(self.out, "WaveActiveBitXor(")?
2906                    }
2907                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2908                        write!(self.out, "WavePrefixSum(")?
2909                    }
2910                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2911                        write!(self.out, "WavePrefixProduct(")?
2912                    }
2913                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2914                        self.write_expr(module, argument, func_ctx)?;
2915                        write!(self.out, " + WavePrefixSum(")?;
2916                    }
2917                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2918                        self.write_expr(module, argument, func_ctx)?;
2919                        write!(self.out, " * WavePrefixProduct(")?;
2920                    }
2921                    _ => unimplemented!(),
2922                }
2923                self.write_expr(module, argument, func_ctx)?;
2924                writeln!(self.out, ");")?;
2925            }
2926            Statement::SubgroupGather {
2927                mode,
2928                argument,
2929                result,
2930            } => {
2931                write!(self.out, "{level}")?;
2932                write!(self.out, "const ")?;
2933                let name = Baked(result).to_string();
2934                match func_ctx.info[result].ty {
2935                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2936                    proc::TypeResolution::Value(ref value) => {
2937                        self.write_value_type(module, value)?
2938                    }
2939                };
2940                write!(self.out, " {name} = ")?;
2941                self.named_expressions.insert(result, name);
2942                match mode {
2943                    crate::GatherMode::BroadcastFirst => {
2944                        write!(self.out, "WaveReadLaneFirst(")?;
2945                        self.write_expr(module, argument, func_ctx)?;
2946                    }
2947                    crate::GatherMode::QuadBroadcast(index) => {
2948                        write!(self.out, "QuadReadLaneAt(")?;
2949                        self.write_expr(module, argument, func_ctx)?;
2950                        write!(self.out, ", ")?;
2951                        self.write_expr(module, index, func_ctx)?;
2952                    }
2953                    crate::GatherMode::QuadSwap(direction) => {
2954                        match direction {
2955                            crate::Direction::X => {
2956                                write!(self.out, "QuadReadAcrossX(")?;
2957                            }
2958                            crate::Direction::Y => {
2959                                write!(self.out, "QuadReadAcrossY(")?;
2960                            }
2961                            crate::Direction::Diagonal => {
2962                                write!(self.out, "QuadReadAcrossDiagonal(")?;
2963                            }
2964                        }
2965                        self.write_expr(module, argument, func_ctx)?;
2966                    }
2967                    _ => {
2968                        write!(self.out, "WaveReadLaneAt(")?;
2969                        self.write_expr(module, argument, func_ctx)?;
2970                        write!(self.out, ", ")?;
2971                        match mode {
2972                            crate::GatherMode::BroadcastFirst => unreachable!(),
2973                            crate::GatherMode::Broadcast(index)
2974                            | crate::GatherMode::Shuffle(index) => {
2975                                self.write_expr(module, index, func_ctx)?;
2976                            }
2977                            crate::GatherMode::ShuffleDown(index) => {
2978                                write!(self.out, "WaveGetLaneIndex() + ")?;
2979                                self.write_expr(module, index, func_ctx)?;
2980                            }
2981                            crate::GatherMode::ShuffleUp(index) => {
2982                                write!(self.out, "WaveGetLaneIndex() - ")?;
2983                                self.write_expr(module, index, func_ctx)?;
2984                            }
2985                            crate::GatherMode::ShuffleXor(index) => {
2986                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
2987                                self.write_expr(module, index, func_ctx)?;
2988                            }
2989                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2990                            crate::GatherMode::QuadSwap(_) => unreachable!(),
2991                        }
2992                    }
2993                }
2994                writeln!(self.out, ");")?;
2995            }
2996            Statement::CooperativeStore { .. } => unimplemented!(),
2997            Statement::RayPipelineFunction(_) => unreachable!(),
2998        }
2999
3000        Ok(())
3001    }
3002
3003    fn write_const_expression(
3004        &mut self,
3005        module: &Module,
3006        expr: Handle<crate::Expression>,
3007        arena: &crate::Arena<crate::Expression>,
3008    ) -> BackendResult {
3009        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3010            writer.write_const_expression(module, expr, arena)
3011        })
3012    }
3013
3014    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3015        match literal {
3016            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3017            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3018            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3019            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3020            // `-2147483648` is parsed by some compilers as unary negation of
3021            // positive 2147483648, which is too large for an int, causing
3022            // issues for some compilers. Neither DXC nor FXC appear to have
3023            // this problem, but this is not specified and could change. We
3024            // therefore use `-2147483647 - 1` as a precaution.
3025            crate::Literal::I32(value) if value == i32::MIN => {
3026                write!(self.out, "int({} - 1)", value + 1)?
3027            }
3028            // HLSL has no suffix for explicit i32 literals, but not using any suffix
3029            // makes the type ambiguous which prevents overload resolution from
3030            // working. So we explicitly use the int() constructor syntax.
3031            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3032            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3033            // I64 version of the minimum I32 value issue described above.
3034            crate::Literal::I64(value) if value == i64::MIN => {
3035                write!(self.out, "({}L - 1L)", value + 1)?;
3036            }
3037            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3038            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3039            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3040                return Err(Error::Custom(
3041                    "Abstract types should not appear in IR presented to backends".into(),
3042                ));
3043            }
3044        }
3045        Ok(())
3046    }
3047
3048    fn write_possibly_const_expression<E>(
3049        &mut self,
3050        module: &Module,
3051        expr: Handle<crate::Expression>,
3052        expressions: &crate::Arena<crate::Expression>,
3053        write_expression: E,
3054    ) -> BackendResult
3055    where
3056        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3057    {
3058        use crate::Expression;
3059
3060        match expressions[expr] {
3061            Expression::Literal(literal) => {
3062                self.write_literal(literal)?;
3063            }
3064            Expression::Constant(handle) => {
3065                let constant = &module.constants[handle];
3066                if constant.name.is_some() {
3067                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3068                } else {
3069                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
3070                }
3071            }
3072            Expression::ZeroValue(ty) => {
3073                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3074                write!(self.out, "()")?;
3075            }
3076            Expression::Compose { ty, ref components } => {
3077                match module.types[ty].inner {
3078                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3079                        self.write_wrapped_constructor_function_name(
3080                            module,
3081                            WrappedConstructor { ty },
3082                        )?;
3083                    }
3084                    _ => {
3085                        self.write_type(module, ty)?;
3086                    }
3087                };
3088                write!(self.out, "(")?;
3089                for (index, component) in components.iter().enumerate() {
3090                    if index != 0 {
3091                        write!(self.out, ", ")?;
3092                    }
3093                    write_expression(self, *component)?;
3094                }
3095                write!(self.out, ")")?;
3096            }
3097            Expression::Splat { size, value } => {
3098                // hlsl is not supported one value constructor
3099                // if we write, for example, int4(0), dxc returns error:
3100                // error: too few elements in vector initialization (expected 4 elements, have 1)
3101                let number_of_components = match size {
3102                    crate::VectorSize::Bi => "xx",
3103                    crate::VectorSize::Tri => "xxx",
3104                    crate::VectorSize::Quad => "xxxx",
3105                };
3106                write!(self.out, "(")?;
3107                write_expression(self, value)?;
3108                write!(self.out, ").{number_of_components}")?
3109            }
3110            _ => {
3111                return Err(Error::Override);
3112            }
3113        }
3114
3115        Ok(())
3116    }
3117
3118    /// Helper method to write expressions
3119    ///
3120    /// # Notes
3121    /// Doesn't add any newlines or leading/trailing spaces
3122    pub(super) fn write_expr(
3123        &mut self,
3124        module: &Module,
3125        expr: Handle<crate::Expression>,
3126        func_ctx: &back::FunctionCtx<'_>,
3127    ) -> BackendResult {
3128        use crate::Expression;
3129
3130        // Handle the special semantics of vertex_index/instance_index
3131        let ff_input = if self.options.special_constants_binding.is_some() {
3132            func_ctx.is_fixed_function_input(expr, module)
3133        } else {
3134            None
3135        };
3136        let closing_bracket = match ff_input {
3137            Some(crate::BuiltIn::VertexIndex) => {
3138                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3139                ")"
3140            }
3141            Some(crate::BuiltIn::InstanceIndex) => {
3142                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3143                ")"
3144            }
3145            Some(crate::BuiltIn::NumWorkGroups) => {
3146                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
3147                // in compute shaders the special constants contain the number
3148                // of workgroups, which we are using here.
3149                write!(
3150                    self.out,
3151                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3152                )?;
3153                return Ok(());
3154            }
3155            _ => "",
3156        };
3157
3158        if let Some(name) = self.named_expressions.get(&expr) {
3159            write!(self.out, "{name}{closing_bracket}")?;
3160            return Ok(());
3161        }
3162
3163        let expression = &func_ctx.expressions[expr];
3164
3165        match *expression {
3166            Expression::Literal(_)
3167            | Expression::Constant(_)
3168            | Expression::ZeroValue(_)
3169            | Expression::Compose { .. }
3170            | Expression::Splat { .. } => {
3171                self.write_possibly_const_expression(
3172                    module,
3173                    expr,
3174                    func_ctx.expressions,
3175                    |writer, expr| writer.write_expr(module, expr, func_ctx),
3176                )?;
3177            }
3178            Expression::Override(_) => return Err(Error::Override),
3179            // Avoid undefined behaviour for addition, subtraction, and
3180            // multiplication of signed integers by casting operands to
3181            // unsigned, performing the operation, then casting the result back
3182            // to signed.
3183            // TODO(#7109): This relies on the asint()/asuint() functions which only work
3184            // for 32-bit types, so we must find another solution for different bit widths.
3185            Expression::Binary {
3186                op:
3187                    op @ crate::BinaryOperator::Add
3188                    | op @ crate::BinaryOperator::Subtract
3189                    | op @ crate::BinaryOperator::Multiply,
3190                left,
3191                right,
3192            } if matches!(
3193                func_ctx.resolve_type(expr, &module.types).scalar(),
3194                Some(Scalar::I32)
3195            ) =>
3196            {
3197                write!(self.out, "asint(asuint(",)?;
3198                self.write_expr(module, left, func_ctx)?;
3199                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3200                self.write_expr(module, right, func_ctx)?;
3201                write!(self.out, "))")?;
3202            }
3203            // All of the multiplication can be expressed as `mul`,
3204            // except vector * vector, which needs to use the "*" operator.
3205            Expression::Binary {
3206                op: crate::BinaryOperator::Multiply,
3207                left,
3208                right,
3209            } if func_ctx.resolve_type(left, &module.types).is_matrix()
3210                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3211            {
3212                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
3213                write!(self.out, "mul(")?;
3214                self.write_expr(module, right, func_ctx)?;
3215                write!(self.out, ", ")?;
3216                self.write_expr(module, left, func_ctx)?;
3217                write!(self.out, ")")?;
3218            }
3219
3220            // WGSL says that floating-point division by zero should return
3221            // infinity. Microsoft's Direct3D 11 functional specification
3222            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
3223            // says:
3224            //
3225            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3226            //
3227            // which is what we want. The DXIL specification for the FDiv
3228            // instruction corroborates this:
3229            //
3230            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3231            Expression::Binary {
3232                op: crate::BinaryOperator::Divide,
3233                left,
3234                right,
3235            } if matches!(
3236                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3237                Some(ScalarKind::Sint | ScalarKind::Uint)
3238            ) =>
3239            {
3240                write!(self.out, "{DIV_FUNCTION}(")?;
3241                self.write_expr(module, left, func_ctx)?;
3242                write!(self.out, ", ")?;
3243                self.write_expr(module, right, func_ctx)?;
3244                write!(self.out, ")")?;
3245            }
3246
3247            Expression::Binary {
3248                op: crate::BinaryOperator::Modulo,
3249                left,
3250                right,
3251            } if matches!(
3252                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3253                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3254            ) =>
3255            {
3256                write!(self.out, "{MOD_FUNCTION}(")?;
3257                self.write_expr(module, left, func_ctx)?;
3258                write!(self.out, ", ")?;
3259                self.write_expr(module, right, func_ctx)?;
3260                write!(self.out, ")")?;
3261            }
3262
3263            Expression::Binary { op, left, right } => {
3264                write!(self.out, "(")?;
3265                self.write_expr(module, left, func_ctx)?;
3266                write!(self.out, " {} ", back::binary_operation_str(op))?;
3267                self.write_expr(module, right, func_ctx)?;
3268                write!(self.out, ")")?;
3269            }
3270            Expression::Access { base, index } => {
3271                if let Some(crate::AddressSpace::Storage { .. }) =
3272                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3273                {
3274                    // do nothing, the chain is written on `Load`/`Store`
3275                } else {
3276                    // We use the function __get_col_of_matCx2 here in cases
3277                    // where `base`s type resolves to a matCx2 and is part of a
3278                    // struct member with type of (possibly nested) array of matCx2's.
3279                    //
3280                    // Note that this only works for `Load`s and we handle
3281                    // `Store`s differently in `Statement::Store`.
3282                    if let Some(MatrixType {
3283                        columns,
3284                        rows: crate::VectorSize::Bi,
3285                        width: 4,
3286                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3287                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3288                    {
3289                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3290                        self.write_expr(module, base, func_ctx)?;
3291                        write!(self.out, ", ")?;
3292                        self.write_expr(module, index, func_ctx)?;
3293                        write!(self.out, ")")?;
3294                        return Ok(());
3295                    }
3296
3297                    let resolved = func_ctx.resolve_type(base, &module.types);
3298
3299                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3300                        TypeInner::BindingArray { .. } => {
3301                            let uniformity = &func_ctx.info[index].uniformity;
3302
3303                            (true, uniformity.non_uniform_result.is_some())
3304                        }
3305                        _ => (false, false),
3306                    };
3307
3308                    self.write_expr(module, base, func_ctx)?;
3309
3310                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3311                        module, func_ctx, base, resolved,
3312                    );
3313
3314                    if let Some(ref info) = array_sampler_info {
3315                        write!(self.out, "{}[", info.sampler_heap_name)?;
3316                    } else {
3317                        write!(self.out, "[")?;
3318                    }
3319
3320                    let needs_bound_check = self.options.restrict_indexing
3321                        && !indexing_binding_array
3322                        && match resolved.pointer_space() {
3323                            Some(
3324                                crate::AddressSpace::Function
3325                                | crate::AddressSpace::Private
3326                                | crate::AddressSpace::WorkGroup
3327                                | crate::AddressSpace::Immediate
3328                                | crate::AddressSpace::TaskPayload
3329                                | crate::AddressSpace::RayPayload
3330                                | crate::AddressSpace::IncomingRayPayload,
3331                            )
3332                            | None => true,
3333                            Some(crate::AddressSpace::Uniform) => {
3334                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3335                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3336                                let bind_target = self
3337                                    .options
3338                                    .resolve_resource_binding(
3339                                        module.global_variables[var_handle]
3340                                            .binding
3341                                            .as_ref()
3342                                            .unwrap(),
3343                                    )
3344                                    .unwrap();
3345                                bind_target.restrict_indexing
3346                            }
3347                            Some(
3348                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3349                            ) => unreachable!(),
3350                        };
3351                    // Decide whether this index needs to be clamped to fall within range.
3352                    let restriction_needed = if needs_bound_check {
3353                        index::access_needs_check(
3354                            base,
3355                            index::GuardedIndex::Expression(index),
3356                            module,
3357                            func_ctx.expressions,
3358                            func_ctx.info,
3359                        )
3360                    } else {
3361                        None
3362                    };
3363                    if let Some(limit) = restriction_needed {
3364                        write!(self.out, "min(uint(")?;
3365                        self.write_expr(module, index, func_ctx)?;
3366                        write!(self.out, "), ")?;
3367                        match limit {
3368                            index::IndexableLength::Known(limit) => {
3369                                write!(self.out, "{}u", limit - 1)?;
3370                            }
3371                            index::IndexableLength::Dynamic => unreachable!(),
3372                        }
3373                        write!(self.out, ")")?;
3374                    } else {
3375                        if non_uniform_qualifier {
3376                            write!(self.out, "NonUniformResourceIndex(")?;
3377                        }
3378                        if let Some(ref info) = array_sampler_info {
3379                            write!(
3380                                self.out,
3381                                "{}[{} + ",
3382                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3383                            )?;
3384                        }
3385                        self.write_expr(module, index, func_ctx)?;
3386                        if array_sampler_info.is_some() {
3387                            write!(self.out, "]")?;
3388                        }
3389                        if non_uniform_qualifier {
3390                            write!(self.out, ")")?;
3391                        }
3392                    }
3393
3394                    write!(self.out, "]")?;
3395                }
3396            }
3397            Expression::AccessIndex { base, index } => {
3398                if let Some(crate::AddressSpace::Storage { .. }) =
3399                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3400                {
3401                    // do nothing, the chain is written on `Load`/`Store`
3402                } else {
3403                    // See if we need to write the matrix column access in a
3404                    // special way since the type of `base` is our special
3405                    // __matCx2 struct.
3406                    if let Some(MatrixType {
3407                        rows: crate::VectorSize::Bi,
3408                        width: 4,
3409                        ..
3410                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3411                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3412                    {
3413                        self.write_expr(module, base, func_ctx)?;
3414                        write!(self.out, "._{index}")?;
3415                        return Ok(());
3416                    }
3417
3418                    let base_ty_res = &func_ctx.info[base].ty;
3419                    let mut resolved = base_ty_res.inner_with(&module.types);
3420                    let base_ty_handle = match *resolved {
3421                        TypeInner::Pointer { base, .. } => {
3422                            resolved = &module.types[base].inner;
3423                            Some(base)
3424                        }
3425                        _ => base_ty_res.handle(),
3426                    };
3427
3428                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3429                    // See the module-level block comment in mod.rs for details.
3430                    //
3431                    // We handle matrix reconstruction here for Loads.
3432                    // Stores are handled directly by `Statement::Store`.
3433                    if let TypeInner::Struct { ref members, .. } = *resolved {
3434                        let member = &members[index as usize];
3435
3436                        match module.types[member.ty].inner {
3437                            TypeInner::Matrix {
3438                                rows: crate::VectorSize::Bi,
3439                                ..
3440                            } if member.binding.is_none() => {
3441                                let ty = base_ty_handle.unwrap();
3442                                self.write_wrapped_struct_matrix_get_function_name(
3443                                    WrappedStructMatrixAccess { ty, index },
3444                                )?;
3445                                write!(self.out, "(")?;
3446                                self.write_expr(module, base, func_ctx)?;
3447                                write!(self.out, ")")?;
3448                                return Ok(());
3449                            }
3450                            _ => {}
3451                        }
3452                    }
3453
3454                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3455                        module, func_ctx, base, resolved,
3456                    );
3457
3458                    if let Some(ref info) = array_sampler_info {
3459                        write!(
3460                            self.out,
3461                            "{}[{}",
3462                            info.sampler_heap_name, info.sampler_index_buffer_name
3463                        )?;
3464                    }
3465
3466                    self.write_expr(module, base, func_ctx)?;
3467
3468                    match *resolved {
3469                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3470                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3471                        // it and generates completely nonsensical DXBC.
3472                        //
3473                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3474                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3475                            // Write vector access as a swizzle
3476                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3477                        }
3478                        TypeInner::Matrix { .. }
3479                        | TypeInner::Array { .. }
3480                        | TypeInner::BindingArray { .. } => {
3481                            if let Some(ref info) = array_sampler_info {
3482                                write!(
3483                                    self.out,
3484                                    "[{} + {index}]",
3485                                    info.binding_array_base_index_name
3486                                )?;
3487                            } else {
3488                                write!(self.out, "[{index}]")?;
3489                            }
3490                        }
3491                        TypeInner::Struct { .. } => {
3492                            // This will never panic in case the type is a `Struct`, this is not true
3493                            // for other types so we can only check while inside this match arm
3494                            let ty = base_ty_handle.unwrap();
3495
3496                            write!(
3497                                self.out,
3498                                ".{}",
3499                                &self.names[&NameKey::StructMember(ty, index)]
3500                            )?
3501                        }
3502                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3503                    }
3504
3505                    if array_sampler_info.is_some() {
3506                        write!(self.out, "]")?;
3507                    }
3508                }
3509            }
3510            Expression::FunctionArgument(pos) => {
3511                let ty = func_ctx.resolve_type(expr, &module.types);
3512
3513                // We know that any external texture function argument has been expanded into
3514                // separate consecutive arguments for each plane and the parameters buffer. And we
3515                // also know that external textures can only ever be used as an argument to another
3516                // function. Therefore we can simply emit each of the expanded arguments in a
3517                // consecutive comma-separated list.
3518                if let TypeInner::Image {
3519                    class: crate::ImageClass::External,
3520                    ..
3521                } = *ty
3522                {
3523                    let plane_names = [0, 1, 2].map(|i| {
3524                        &self.names[&func_ctx
3525                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3526                    });
3527                    let params_name = &self.names[&func_ctx
3528                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3529                    write!(
3530                        self.out,
3531                        "{}, {}, {}, {}",
3532                        plane_names[0], plane_names[1], plane_names[2], params_name
3533                    )?;
3534                } else {
3535                    let key = func_ctx.argument_key(pos);
3536                    let name = &self.names[&key];
3537                    write!(self.out, "{name}")?;
3538                }
3539            }
3540            Expression::ImageSample {
3541                coordinate,
3542                image,
3543                sampler,
3544                clamp_to_edge: true,
3545                gather: None,
3546                array_index: None,
3547                offset: None,
3548                level: crate::SampleLevel::Zero,
3549                depth_ref: None,
3550            } => {
3551                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3552                self.write_expr(module, image, func_ctx)?;
3553                write!(self.out, ", ")?;
3554                self.write_expr(module, sampler, func_ctx)?;
3555                write!(self.out, ", ")?;
3556                self.write_expr(module, coordinate, func_ctx)?;
3557                write!(self.out, ")")?;
3558            }
3559            Expression::ImageSample {
3560                image,
3561                sampler,
3562                gather,
3563                coordinate,
3564                array_index,
3565                offset,
3566                level,
3567                depth_ref,
3568                clamp_to_edge,
3569            } => {
3570                if clamp_to_edge {
3571                    return Err(Error::Custom(
3572                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3573                    ));
3574                }
3575
3576                use crate::SampleLevel as Sl;
3577                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3578
3579                let (base_str, component_str) = match gather {
3580                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3581                    None => ("Sample", ""),
3582                };
3583                let cmp_str = match depth_ref {
3584                    Some(_) => "Cmp",
3585                    None => "",
3586                };
3587                let level_str = match level {
3588                    Sl::Zero if gather.is_none() => "LevelZero",
3589                    Sl::Auto | Sl::Zero => "",
3590                    Sl::Exact(_) => "Level",
3591                    Sl::Bias(_) => "Bias",
3592                    Sl::Gradient { .. } => "Grad",
3593                };
3594
3595                self.write_expr(module, image, func_ctx)?;
3596                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3597                self.write_expr(module, sampler, func_ctx)?;
3598                write!(self.out, ", ")?;
3599                self.write_texture_coordinates(
3600                    "float",
3601                    coordinate,
3602                    array_index,
3603                    None,
3604                    module,
3605                    func_ctx,
3606                )?;
3607
3608                if let Some(depth_ref) = depth_ref {
3609                    write!(self.out, ", ")?;
3610                    self.write_expr(module, depth_ref, func_ctx)?;
3611                }
3612
3613                match level {
3614                    Sl::Auto | Sl::Zero => {}
3615                    Sl::Exact(expr) => {
3616                        write!(self.out, ", ")?;
3617                        self.write_expr(module, expr, func_ctx)?;
3618                    }
3619                    Sl::Bias(expr) => {
3620                        write!(self.out, ", ")?;
3621                        self.write_expr(module, expr, func_ctx)?;
3622                    }
3623                    Sl::Gradient { x, y } => {
3624                        write!(self.out, ", ")?;
3625                        self.write_expr(module, x, func_ctx)?;
3626                        write!(self.out, ", ")?;
3627                        self.write_expr(module, y, func_ctx)?;
3628                    }
3629                }
3630
3631                if let Some(offset) = offset {
3632                    write!(self.out, ", ")?;
3633                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3634                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3635                    write!(self.out, ")")?;
3636                }
3637
3638                write!(self.out, ")")?;
3639            }
3640            Expression::ImageQuery { image, query } => {
3641                // use wrapped image query function
3642                if let TypeInner::Image {
3643                    dim,
3644                    arrayed,
3645                    class,
3646                } = *func_ctx.resolve_type(image, &module.types)
3647                {
3648                    let wrapped_image_query = WrappedImageQuery {
3649                        dim,
3650                        arrayed,
3651                        class,
3652                        query: query.into(),
3653                    };
3654
3655                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3656                    write!(self.out, "(")?;
3657                    // Image always first param
3658                    self.write_expr(module, image, func_ctx)?;
3659                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3660                        write!(self.out, ", ")?;
3661                        self.write_expr(module, level, func_ctx)?;
3662                    }
3663                    write!(self.out, ")")?;
3664                }
3665            }
3666            Expression::ImageLoad {
3667                image,
3668                coordinate,
3669                array_index,
3670                sample,
3671                level,
3672            } => self.write_image_load(
3673                &module,
3674                expr,
3675                func_ctx,
3676                image,
3677                coordinate,
3678                array_index,
3679                sample,
3680                level,
3681            )?,
3682            Expression::GlobalVariable(handle) => {
3683                let global_variable = &module.global_variables[handle];
3684                let ty = &module.types[global_variable.ty].inner;
3685
3686                // In the case of binding arrays of samplers, we need to not write anything
3687                // as the we are in the wrong position to fully write the expression.
3688                //
3689                // The entire writing is done by AccessIndex.
3690                let is_binding_array_of_samplers = match *ty {
3691                    TypeInner::BindingArray { base, .. } => {
3692                        let base_ty = &module.types[base].inner;
3693                        matches!(*base_ty, TypeInner::Sampler { .. })
3694                    }
3695                    _ => false,
3696                };
3697
3698                let is_storage_space =
3699                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3700
3701                // Our external texture global variable has been expanded into multiple
3702                // global variables, one for each plane and the parameters buffer.
3703                // External textures can only ever be used as arguments to a function
3704                // call, and we know that an external texture argument to any function
3705                // will have been expanded to separate consecutive arguments for each
3706                // plane and the parameters buffer. Therefore we can simply emit each of
3707                // the expanded global variables in a consecutive comma-separated list.
3708                if let TypeInner::Image {
3709                    class: crate::ImageClass::External,
3710                    ..
3711                } = *ty
3712                {
3713                    let plane_names = [0, 1, 2].map(|i| {
3714                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3715                            handle,
3716                            ExternalTextureNameKey::Plane(i),
3717                        )]
3718                    });
3719                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3720                        handle,
3721                        ExternalTextureNameKey::Params,
3722                    )];
3723                    write!(
3724                        self.out,
3725                        "{}, {}, {}, {}",
3726                        plane_names[0], plane_names[1], plane_names[2], params_name
3727                    )?;
3728                } else if !is_binding_array_of_samplers && !is_storage_space {
3729                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3730                    write!(self.out, "{name}")?;
3731                }
3732            }
3733            Expression::LocalVariable(handle) => {
3734                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3735            }
            Expression::Load { pointer } => {
                // How a load is emitted depends on the address space of the pointer.
                match func_ctx
                    .resolve_type(pointer, &module.types)
                    .pointer_space()
                {
                    // Storage-space pointers: collect the access chain with
                    // `fill_access_chain`, then emit the read through `write_storage_load`
                    // using this expression's resolved result type.
                    Some(crate::AddressSpace::Storage { .. }) => {
                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
                        let result_ty = func_ctx.info[expr].ty.clone();
                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
                    }
                    // Every other address space: emit the pointer expression itself,
                    // optionally wrapped in a C-style cast (see below).
                    _ => {
                        // Set when a `((T)` prefix was written and needs a closing `)`.
                        let mut close_paren = false;

                        // We cast the value loaded to a native HLSL floatCx2
                        // in cases where it is of type:
                        //  - __matCx2 or
                        //  - a (possibly nested) array of __matCx2's
                        // The matrix is looked up either as a struct/array member or as a
                        // global uniform; only two-row matrices with `width: 4` (presumably
                        // 4-byte float components — confirm against `MatrixType`) qualify.
                        if let Some(MatrixType {
                            rows: crate::VectorSize::Bi,
                            width: 4,
                            ..
                        }) = get_inner_matrix_of_struct_array_member(
                            module, pointer, func_ctx, false,
                        )
                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
                        {
                            // Cast to the pointee type when the pointer's type exposes one;
                            // otherwise fall back to the resolved type itself.
                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
                            let ptr_tr = resolved.pointer_base_type();
                            if let Some(ptr_ty) =
                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
                            {
                                resolved = ptr_ty;
                            }

                            write!(self.out, "((")?;
                            // Array casts are spelled as the element type followed by the
                            // array size suffix.
                            if let TypeInner::Array { base, size, .. } = *resolved {
                                self.write_type(module, base)?;
                                self.write_array_size(module, base, size)?;
                            } else {
                                self.write_value_type(module, resolved)?;
                            }
                            write!(self.out, ")")?;
                            close_paren = true;
                        }

                        self.write_expr(module, pointer, func_ctx)?;

                        if close_paren {
                            write!(self.out, ")")?;
                        }
                    }
                }
            }
3789            Expression::Unary { op, expr } => {
3790                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3791                let op_str = match op {
3792                    crate::UnaryOperator::Negate => {
3793                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3794                            Some(Scalar::I32) => NEG_FUNCTION,
3795                            _ => "-",
3796                        }
3797                    }
3798                    crate::UnaryOperator::LogicalNot => "!",
3799                    crate::UnaryOperator::BitwiseNot => "~",
3800                };
3801                write!(self.out, "{op_str}(")?;
3802                self.write_expr(module, expr, func_ctx)?;
3803                write!(self.out, ")")?;
3804            }
3805            Expression::As {
3806                expr,
3807                kind,
3808                convert,
3809            } => {
3810                let inner = func_ctx.resolve_type(expr, &module.types);
3811                if inner.scalar_kind() == Some(ScalarKind::Float)
3812                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3813                    && convert.is_some()
3814                {
3815                    // Use helper functions for float to int casts in order to
3816                    // avoid undefined behaviour when value is out of range for
3817                    // the target type.
3818                    let fun_name = match (kind, convert) {
3819                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3820                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3821                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3822                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3823                        _ => unreachable!(),
3824                    };
3825                    write!(self.out, "{fun_name}(")?;
3826                    self.write_expr(module, expr, func_ctx)?;
3827                    write!(self.out, ")")?;
3828                } else {
3829                    let close_paren = match convert {
3830                        Some(dst_width) => {
3831                            let scalar = Scalar {
3832                                kind,
3833                                width: dst_width,
3834                            };
3835                            match *inner {
3836                                TypeInner::Vector { size, .. } => {
3837                                    write!(
3838                                        self.out,
3839                                        "{}{}(",
3840                                        scalar.to_hlsl_str()?,
3841                                        common::vector_size_str(size)
3842                                    )?;
3843                                }
3844                                TypeInner::Scalar(_) => {
3845                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3846                                }
3847                                TypeInner::Matrix { columns, rows, .. } => {
3848                                    write!(
3849                                        self.out,
3850                                        "{}{}x{}(",
3851                                        scalar.to_hlsl_str()?,
3852                                        common::vector_size_str(columns),
3853                                        common::vector_size_str(rows)
3854                                    )?;
3855                                }
3856                                _ => {
3857                                    return Err(Error::Unimplemented(format!(
3858                                        "write_expr expression::as {inner:?}"
3859                                    )));
3860                                }
3861                            };
3862                            true
3863                        }
3864                        None => {
3865                            if inner.scalar_width() == Some(8) {
3866                                false
3867                            } else {
3868                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3869                                true
3870                            }
3871                        }
3872                    };
3873                    self.write_expr(module, expr, func_ctx)?;
3874                    if close_paren {
3875                        write!(self.out, ")")?;
3876                    }
3877                }
3878            }
3879            Expression::Math {
3880                fun,
3881                arg,
3882                arg1,
3883                arg2,
3884                arg3,
3885            } => {
3886                use crate::MathFunction as Mf;
3887
3888                enum Function {
3889                    Asincosh { is_sin: bool },
3890                    Atanh,
3891                    Pack2x16float,
3892                    Pack2x16snorm,
3893                    Pack2x16unorm,
3894                    Pack4x8snorm,
3895                    Pack4x8unorm,
3896                    Pack4xI8,
3897                    Pack4xU8,
3898                    Pack4xI8Clamp,
3899                    Pack4xU8Clamp,
3900                    Unpack2x16float,
3901                    Unpack2x16snorm,
3902                    Unpack2x16unorm,
3903                    Unpack4x8snorm,
3904                    Unpack4x8unorm,
3905                    Unpack4xI8,
3906                    Unpack4xU8,
3907                    Dot4I8Packed,
3908                    Dot4U8Packed,
3909                    QuantizeToF16,
3910                    Regular(&'static str),
3911                    MissingIntOverload(&'static str),
3912                    MissingIntReturnType(&'static str),
3913                    CountTrailingZeros,
3914                    CountLeadingZeros,
3915                }
3916
3917                let fun = match fun {
3918                    // comparison
3919                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3920                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3921                        _ => Function::Regular("abs"),
3922                    },
3923                    Mf::Min => Function::Regular("min"),
3924                    Mf::Max => Function::Regular("max"),
3925                    Mf::Clamp => Function::Regular("clamp"),
3926                    Mf::Saturate => Function::Regular("saturate"),
3927                    // trigonometry
3928                    Mf::Cos => Function::Regular("cos"),
3929                    Mf::Cosh => Function::Regular("cosh"),
3930                    Mf::Sin => Function::Regular("sin"),
3931                    Mf::Sinh => Function::Regular("sinh"),
3932                    Mf::Tan => Function::Regular("tan"),
3933                    Mf::Tanh => Function::Regular("tanh"),
3934                    Mf::Acos => Function::Regular("acos"),
3935                    Mf::Asin => Function::Regular("asin"),
3936                    Mf::Atan => Function::Regular("atan"),
3937                    Mf::Atan2 => Function::Regular("atan2"),
3938                    Mf::Asinh => Function::Asincosh { is_sin: true },
3939                    Mf::Acosh => Function::Asincosh { is_sin: false },
3940                    Mf::Atanh => Function::Atanh,
3941                    Mf::Radians => Function::Regular("radians"),
3942                    Mf::Degrees => Function::Regular("degrees"),
3943                    // decomposition
3944                    Mf::Ceil => Function::Regular("ceil"),
3945                    Mf::Floor => Function::Regular("floor"),
3946                    Mf::Round => Function::Regular("round"),
3947                    Mf::Fract => Function::Regular("frac"),
3948                    Mf::Trunc => Function::Regular("trunc"),
3949                    Mf::Modf => Function::Regular(MODF_FUNCTION),
3950                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3951                    Mf::Ldexp => Function::Regular("ldexp"),
3952                    // exponent
3953                    Mf::Exp => Function::Regular("exp"),
3954                    Mf::Exp2 => Function::Regular("exp2"),
3955                    Mf::Log => Function::Regular("log"),
3956                    Mf::Log2 => Function::Regular("log2"),
3957                    Mf::Pow => Function::Regular("pow"),
3958                    // geometry
3959                    Mf::Dot => Function::Regular("dot"),
3960                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
3961                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
3962                    //Mf::Outer => ,
3963                    Mf::Cross => Function::Regular("cross"),
3964                    Mf::Distance => Function::Regular("distance"),
3965                    Mf::Length => Function::Regular("length"),
3966                    Mf::Normalize => Function::Regular("normalize"),
3967                    Mf::FaceForward => Function::Regular("faceforward"),
3968                    Mf::Reflect => Function::Regular("reflect"),
3969                    Mf::Refract => Function::Regular("refract"),
3970                    // computational
3971                    Mf::Sign => Function::Regular("sign"),
3972                    Mf::Fma => Function::Regular("mad"),
3973                    Mf::Mix => Function::Regular("lerp"),
3974                    Mf::Step => Function::Regular("step"),
3975                    Mf::SmoothStep => Function::Regular("smoothstep"),
3976                    Mf::Sqrt => Function::Regular("sqrt"),
3977                    Mf::InverseSqrt => Function::Regular("rsqrt"),
3978                    //Mf::Inverse =>,
3979                    Mf::Transpose => Function::Regular("transpose"),
3980                    Mf::Determinant => Function::Regular("determinant"),
3981                    Mf::QuantizeToF16 => Function::QuantizeToF16,
3982                    // bits
3983                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
3984                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
3985                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3986                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3987                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3988                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3989                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3990                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3991                    // Data Packing
3992                    Mf::Pack2x16float => Function::Pack2x16float,
3993                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
3994                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
3995                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
3996                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
3997                    Mf::Pack4xI8 => Function::Pack4xI8,
3998                    Mf::Pack4xU8 => Function::Pack4xU8,
3999                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4000                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4001                    // Data Unpacking
4002                    Mf::Unpack2x16float => Function::Unpack2x16float,
4003                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4004                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4005                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4006                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4007                    Mf::Unpack4xI8 => Function::Unpack4xI8,
4008                    Mf::Unpack4xU8 => Function::Unpack4xU8,
4009                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4010                };
4011
4012                match fun {
4013                    Function::Asincosh { is_sin } => {
4014                        write!(self.out, "log(")?;
4015                        self.write_expr(module, arg, func_ctx)?;
4016                        write!(self.out, " + sqrt(")?;
4017                        self.write_expr(module, arg, func_ctx)?;
4018                        write!(self.out, " * ")?;
4019                        self.write_expr(module, arg, func_ctx)?;
4020                        match is_sin {
4021                            true => write!(self.out, " + 1.0))")?,
4022                            false => write!(self.out, " - 1.0))")?,
4023                        }
4024                    }
4025                    Function::Atanh => {
4026                        write!(self.out, "0.5 * log((1.0 + ")?;
4027                        self.write_expr(module, arg, func_ctx)?;
4028                        write!(self.out, ") / (1.0 - ")?;
4029                        self.write_expr(module, arg, func_ctx)?;
4030                        write!(self.out, "))")?;
4031                    }
4032                    Function::Pack2x16float => {
4033                        write!(self.out, "(f32tof16(")?;
4034                        self.write_expr(module, arg, func_ctx)?;
4035                        write!(self.out, "[0]) | f32tof16(")?;
4036                        self.write_expr(module, arg, func_ctx)?;
4037                        write!(self.out, "[1]) << 16)")?;
4038                    }
4039                    Function::Pack2x16snorm => {
4040                        let scale = 32767;
4041
4042                        write!(self.out, "uint((int(round(clamp(")?;
4043                        self.write_expr(module, arg, func_ctx)?;
4044                        write!(
4045                            self.out,
4046                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4047                        )?;
4048                        self.write_expr(module, arg, func_ctx)?;
4049                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4050                    }
4051                    Function::Pack2x16unorm => {
4052                        let scale = 65535;
4053
4054                        write!(self.out, "(uint(round(clamp(")?;
4055                        self.write_expr(module, arg, func_ctx)?;
4056                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4057                        self.write_expr(module, arg, func_ctx)?;
4058                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4059                    }
4060                    Function::Pack4x8snorm => {
4061                        let scale = 127;
4062
4063                        write!(self.out, "uint((int(round(clamp(")?;
4064                        self.write_expr(module, arg, func_ctx)?;
4065                        write!(
4066                            self.out,
4067                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4068                        )?;
4069                        self.write_expr(module, arg, func_ctx)?;
4070                        write!(
4071                            self.out,
4072                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4073                        )?;
4074                        self.write_expr(module, arg, func_ctx)?;
4075                        write!(
4076                            self.out,
4077                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4078                        )?;
4079                        self.write_expr(module, arg, func_ctx)?;
4080                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4081                    }
4082                    Function::Pack4x8unorm => {
4083                        let scale = 255;
4084
4085                        write!(self.out, "(uint(round(clamp(")?;
4086                        self.write_expr(module, arg, func_ctx)?;
4087                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4088                        self.write_expr(module, arg, func_ctx)?;
4089                        write!(
4090                            self.out,
4091                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4092                        )?;
4093                        self.write_expr(module, arg, func_ctx)?;
4094                        write!(
4095                            self.out,
4096                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4097                        )?;
4098                        self.write_expr(module, arg, func_ctx)?;
4099                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4100                    }
4101                    fun @ (Function::Pack4xI8
4102                    | Function::Pack4xU8
4103                    | Function::Pack4xI8Clamp
4104                    | Function::Pack4xU8Clamp) => {
4105                        let was_signed =
4106                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4107                        let clamp_bounds = match fun {
4108                            Function::Pack4xI8Clamp => Some(("-128", "127")),
4109                            Function::Pack4xU8Clamp => Some(("0", "255")),
4110                            _ => None,
4111                        };
4112                        if was_signed {
4113                            write!(self.out, "uint(")?;
4114                        }
4115                        let write_arg = |this: &mut Self| -> BackendResult {
4116                            if let Some((min, max)) = clamp_bounds {
4117                                write!(this.out, "clamp(")?;
4118                                this.write_expr(module, arg, func_ctx)?;
4119                                write!(this.out, ", {min}, {max})")?;
4120                            } else {
4121                                this.write_expr(module, arg, func_ctx)?;
4122                            }
4123                            Ok(())
4124                        };
4125                        write!(self.out, "(")?;
4126                        write_arg(self)?;
4127                        write!(self.out, "[0] & 0xFF) | ((")?;
4128                        write_arg(self)?;
4129                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4130                        write_arg(self)?;
4131                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4132                        write_arg(self)?;
4133                        write!(self.out, "[3] & 0xFF) << 24)")?;
4134                        if was_signed {
4135                            write!(self.out, ")")?;
4136                        }
4137                    }
4138
4139                    Function::Unpack2x16float => {
4140                        write!(self.out, "float2(f16tof32(")?;
4141                        self.write_expr(module, arg, func_ctx)?;
4142                        write!(self.out, "), f16tof32((")?;
4143                        self.write_expr(module, arg, func_ctx)?;
4144                        write!(self.out, ") >> 16))")?;
4145                    }
4146                    Function::Unpack2x16snorm => {
4147                        let scale = 32767;
4148
4149                        write!(self.out, "(float2(int2(")?;
4150                        self.write_expr(module, arg, func_ctx)?;
4151                        write!(self.out, " << 16, ")?;
4152                        self.write_expr(module, arg, func_ctx)?;
4153                        write!(self.out, ") >> 16) / {scale}.0)")?;
4154                    }
4155                    Function::Unpack2x16unorm => {
4156                        let scale = 65535;
4157
4158                        write!(self.out, "(float2(")?;
4159                        self.write_expr(module, arg, func_ctx)?;
4160                        write!(self.out, " & 0xFFFF, ")?;
4161                        self.write_expr(module, arg, func_ctx)?;
4162                        write!(self.out, " >> 16) / {scale}.0)")?;
4163                    }
4164                    Function::Unpack4x8snorm => {
4165                        let scale = 127;
4166
4167                        write!(self.out, "(float4(int4(")?;
4168                        self.write_expr(module, arg, func_ctx)?;
4169                        write!(self.out, " << 24, ")?;
4170                        self.write_expr(module, arg, func_ctx)?;
4171                        write!(self.out, " << 16, ")?;
4172                        self.write_expr(module, arg, func_ctx)?;
4173                        write!(self.out, " << 8, ")?;
4174                        self.write_expr(module, arg, func_ctx)?;
4175                        write!(self.out, ") >> 24) / {scale}.0)")?;
4176                    }
4177                    Function::Unpack4x8unorm => {
4178                        let scale = 255;
4179
4180                        write!(self.out, "(float4(")?;
4181                        self.write_expr(module, arg, func_ctx)?;
4182                        write!(self.out, " & 0xFF, ")?;
4183                        self.write_expr(module, arg, func_ctx)?;
4184                        write!(self.out, " >> 8 & 0xFF, ")?;
4185                        self.write_expr(module, arg, func_ctx)?;
4186                        write!(self.out, " >> 16 & 0xFF, ")?;
4187                        self.write_expr(module, arg, func_ctx)?;
4188                        write!(self.out, " >> 24) / {scale}.0)")?;
4189                    }
4190                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4191                        write!(self.out, "(")?;
4192                        if matches!(fun, Function::Unpack4xU8) {
4193                            write!(self.out, "u")?;
4194                        }
4195                        write!(self.out, "int4(")?;
4196                        self.write_expr(module, arg, func_ctx)?;
4197                        write!(self.out, ", ")?;
4198                        self.write_expr(module, arg, func_ctx)?;
4199                        write!(self.out, " >> 8, ")?;
4200                        self.write_expr(module, arg, func_ctx)?;
4201                        write!(self.out, " >> 16, ")?;
4202                        self.write_expr(module, arg, func_ctx)?;
4203                        write!(self.out, " >> 24) << 24 >> 24)")?;
4204                    }
4205                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4206                        let arg1 = arg1.unwrap();
4207
4208                        if self.options.shader_model >= ShaderModel::V6_4 {
4209                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
4210                            let function_name = match fun {
4211                                Function::Dot4I8Packed => "dot4add_i8packed",
4212                                Function::Dot4U8Packed => "dot4add_u8packed",
4213                                _ => unreachable!(),
4214                            };
4215                            write!(self.out, "{function_name}(")?;
4216                            self.write_expr(module, arg, func_ctx)?;
4217                            write!(self.out, ", ")?;
4218                            self.write_expr(module, arg1, func_ctx)?;
4219                            write!(self.out, ", 0)")?;
4220                        } else {
4221                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
4222                            write!(self.out, "dot(")?;
4223
4224                            if matches!(fun, Function::Dot4U8Packed) {
4225                                write!(self.out, "u")?;
4226                            }
4227                            write!(self.out, "int4(")?;
4228                            self.write_expr(module, arg, func_ctx)?;
4229                            write!(self.out, ", ")?;
4230                            self.write_expr(module, arg, func_ctx)?;
4231                            write!(self.out, " >> 8, ")?;
4232                            self.write_expr(module, arg, func_ctx)?;
4233                            write!(self.out, " >> 16, ")?;
4234                            self.write_expr(module, arg, func_ctx)?;
4235                            write!(self.out, " >> 24) << 24 >> 24, ")?;
4236
4237                            if matches!(fun, Function::Dot4U8Packed) {
4238                                write!(self.out, "u")?;
4239                            }
4240                            write!(self.out, "int4(")?;
4241                            self.write_expr(module, arg1, func_ctx)?;
4242                            write!(self.out, ", ")?;
4243                            self.write_expr(module, arg1, func_ctx)?;
4244                            write!(self.out, " >> 8, ")?;
4245                            self.write_expr(module, arg1, func_ctx)?;
4246                            write!(self.out, " >> 16, ")?;
4247                            self.write_expr(module, arg1, func_ctx)?;
4248                            write!(self.out, " >> 24) << 24 >> 24)")?;
4249                        }
4250                    }
4251                    Function::QuantizeToF16 => {
4252                        write!(self.out, "f16tof32(f32tof16(")?;
4253                        self.write_expr(module, arg, func_ctx)?;
4254                        write!(self.out, "))")?;
4255                    }
4256                    Function::Regular(fun_name) => {
4257                        write!(self.out, "{fun_name}(")?;
4258                        self.write_expr(module, arg, func_ctx)?;
4259                        if let Some(arg) = arg1 {
4260                            write!(self.out, ", ")?;
4261                            self.write_expr(module, arg, func_ctx)?;
4262                        }
4263                        if let Some(arg) = arg2 {
4264                            write!(self.out, ", ")?;
4265                            self.write_expr(module, arg, func_ctx)?;
4266                        }
4267                        if let Some(arg) = arg3 {
4268                            write!(self.out, ", ")?;
4269                            self.write_expr(module, arg, func_ctx)?;
4270                        }
4271                        write!(self.out, ")")?
4272                    }
4273                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4274                    // as non-32bit types are DXC only.
4275                    Function::MissingIntOverload(fun_name) => {
4276                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4277                        if let Some(Scalar::I32) = scalar_kind {
4278                            write!(self.out, "asint({fun_name}(asuint(")?;
4279                            self.write_expr(module, arg, func_ctx)?;
4280                            write!(self.out, ")))")?;
4281                        } else {
4282                            write!(self.out, "{fun_name}(")?;
4283                            self.write_expr(module, arg, func_ctx)?;
4284                            write!(self.out, ")")?;
4285                        }
4286                    }
4287                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4288                    // as non-32bit types are DXC only.
4289                    Function::MissingIntReturnType(fun_name) => {
4290                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4291                        if let Some(Scalar::I32) = scalar_kind {
4292                            write!(self.out, "asint({fun_name}(")?;
4293                            self.write_expr(module, arg, func_ctx)?;
4294                            write!(self.out, "))")?;
4295                        } else {
4296                            write!(self.out, "{fun_name}(")?;
4297                            self.write_expr(module, arg, func_ctx)?;
4298                            write!(self.out, ")")?;
4299                        }
4300                    }
4301                    Function::CountTrailingZeros => {
4302                        match *func_ctx.resolve_type(arg, &module.types) {
4303                            TypeInner::Vector { size, scalar } => {
4304                                let s = match size {
4305                                    crate::VectorSize::Bi => ".xx",
4306                                    crate::VectorSize::Tri => ".xxx",
4307                                    crate::VectorSize::Quad => ".xxxx",
4308                                };
4309
4310                                let scalar_width_bits = scalar.width * 8;
4311
4312                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4313                                    write!(
4314                                        self.out,
4315                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4316                                    )?;
4317                                    self.write_expr(module, arg, func_ctx)?;
4318                                    write!(self.out, "))")?;
4319                                } else {
4320                                    // This is only needed for the FXC path, on 32bit signed integers.
4321                                    write!(
4322                                        self.out,
4323                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4324                                    )?;
4325                                    self.write_expr(module, arg, func_ctx)?;
4326                                    write!(self.out, ")))")?;
4327                                }
4328                            }
4329                            TypeInner::Scalar(scalar) => {
4330                                let scalar_width_bits = scalar.width * 8;
4331
4332                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4333                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4334                                    self.write_expr(module, arg, func_ctx)?;
4335                                    write!(self.out, "))")?;
4336                                } else {
4337                                    // This is only needed for the FXC path, on 32bit signed integers.
4338                                    write!(
4339                                        self.out,
4340                                        "asint(min({scalar_width_bits}u, firstbitlow("
4341                                    )?;
4342                                    self.write_expr(module, arg, func_ctx)?;
4343                                    write!(self.out, ")))")?;
4344                                }
4345                            }
4346                            _ => unreachable!(),
4347                        }
4348
4349                        return Ok(());
4350                    }
4351                    Function::CountLeadingZeros => {
4352                        match *func_ctx.resolve_type(arg, &module.types) {
4353                            TypeInner::Vector { size, scalar } => {
4354                                let s = match size {
4355                                    crate::VectorSize::Bi => ".xx",
4356                                    crate::VectorSize::Tri => ".xxx",
4357                                    crate::VectorSize::Quad => ".xxxx",
4358                                };
4359
4360                                // scalar width - 1
4361                                let constant = scalar.width * 8 - 1;
4362
4363                                if scalar.kind == ScalarKind::Uint {
4364                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4365                                    self.write_expr(module, arg, func_ctx)?;
4366                                    write!(self.out, "))")?;
4367                                } else {
4368                                    let conversion_func = match scalar.width {
4369                                        4 => "asint",
4370                                        _ => "",
4371                                    };
4372                                    write!(self.out, "(")?;
4373                                    self.write_expr(module, arg, func_ctx)?;
4374                                    write!(
4375                                        self.out,
4376                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4377                                    )?;
4378                                    self.write_expr(module, arg, func_ctx)?;
4379                                    write!(self.out, ")))")?;
4380                                }
4381                            }
4382                            TypeInner::Scalar(scalar) => {
4383                                // scalar width - 1
4384                                let constant = scalar.width * 8 - 1;
4385
4386                                if let ScalarKind::Uint = scalar.kind {
4387                                    write!(self.out, "({constant}u - firstbithigh(")?;
4388                                    self.write_expr(module, arg, func_ctx)?;
4389                                    write!(self.out, "))")?;
4390                                } else {
4391                                    let conversion_func = match scalar.width {
4392                                        4 => "asint",
4393                                        _ => "",
4394                                    };
4395                                    write!(self.out, "(")?;
4396                                    self.write_expr(module, arg, func_ctx)?;
4397                                    write!(
4398                                        self.out,
4399                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4400                                    )?;
4401                                    self.write_expr(module, arg, func_ctx)?;
4402                                    write!(self.out, ")))")?;
4403                                }
4404                            }
4405                            _ => unreachable!(),
4406                        }
4407
4408                        return Ok(());
4409                    }
4410                }
4411            }
4412            Expression::Swizzle {
4413                size,
4414                vector,
4415                pattern,
4416            } => {
4417                self.write_expr(module, vector, func_ctx)?;
4418                write!(self.out, ".")?;
4419                for &sc in pattern[..size as usize].iter() {
4420                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4421                }
4422            }
4423            Expression::ArrayLength(expr) => {
4424                let var_handle = match func_ctx.expressions[expr] {
4425                    Expression::AccessIndex { base, index: _ } => {
4426                        match func_ctx.expressions[base] {
4427                            Expression::GlobalVariable(handle) => handle,
4428                            _ => unreachable!(),
4429                        }
4430                    }
4431                    Expression::GlobalVariable(handle) => handle,
4432                    _ => unreachable!(),
4433                };
4434
4435                let var = &module.global_variables[var_handle];
4436                let (offset, stride) = match module.types[var.ty].inner {
4437                    TypeInner::Array { stride, .. } => (0, stride),
4438                    TypeInner::Struct { ref members, .. } => {
4439                        let last = members.last().unwrap();
4440                        let stride = match module.types[last.ty].inner {
4441                            TypeInner::Array { stride, .. } => stride,
4442                            _ => unreachable!(),
4443                        };
4444                        (last.offset, stride)
4445                    }
4446                    _ => unreachable!(),
4447                };
4448
4449                let storage_access = match var.space {
4450                    crate::AddressSpace::Storage { access } => access,
4451                    _ => crate::StorageAccess::default(),
4452                };
4453                let wrapped_array_length = WrappedArrayLength {
4454                    writable: storage_access.contains(crate::StorageAccess::STORE),
4455                };
4456
4457                write!(self.out, "((")?;
4458                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4459                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4460                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4461            }
4462            Expression::Derivative { axis, ctrl, expr } => {
4463                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4464                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4465                    let tail = match ctrl {
4466                        Ctrl::Coarse => "coarse",
4467                        Ctrl::Fine => "fine",
4468                        Ctrl::None => unreachable!(),
4469                    };
4470                    write!(self.out, "abs(ddx_{tail}(")?;
4471                    self.write_expr(module, expr, func_ctx)?;
4472                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4473                    self.write_expr(module, expr, func_ctx)?;
4474                    write!(self.out, "))")?
4475                } else {
4476                    let fun_str = match (axis, ctrl) {
4477                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4478                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4479                        (Axis::X, Ctrl::None) => "ddx",
4480                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4481                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4482                        (Axis::Y, Ctrl::None) => "ddy",
4483                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4484                        (Axis::Width, Ctrl::None) => "fwidth",
4485                    };
4486                    write!(self.out, "{fun_str}(")?;
4487                    self.write_expr(module, expr, func_ctx)?;
4488                    write!(self.out, ")")?
4489                }
4490            }
4491            Expression::Relational { fun, argument } => {
4492                use crate::RelationalFunction as Rf;
4493
4494                let fun_str = match fun {
4495                    Rf::All => "all",
4496                    Rf::Any => "any",
4497                    Rf::IsNan => "isnan",
4498                    Rf::IsInf => "isinf",
4499                };
4500                write!(self.out, "{fun_str}(")?;
4501                self.write_expr(module, argument, func_ctx)?;
4502                write!(self.out, ")")?
4503            }
4504            Expression::Select {
4505                condition,
4506                accept,
4507                reject,
4508            } => {
4509                write!(self.out, "(")?;
4510                self.write_expr(module, condition, func_ctx)?;
4511                write!(self.out, " ? ")?;
4512                self.write_expr(module, accept, func_ctx)?;
4513                write!(self.out, " : ")?;
4514                self.write_expr(module, reject, func_ctx)?;
4515                write!(self.out, ")")?
4516            }
4517            Expression::RayQueryGetIntersection { query, committed } => {
4518                // For reasoning, see write_stmt
4519                let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4520                    unreachable!()
4521                };
4522
4523                let tracker_expr_name = format!(
4524                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4525                    self.names[&func_ctx.name_key(query_var)]
4526                );
4527
4528                if committed {
4529                    write!(self.out, "GetCommittedIntersection(")?;
4530                    self.write_expr(module, query, func_ctx)?;
4531                    write!(self.out, ", {tracker_expr_name})")?;
4532                } else {
4533                    write!(self.out, "GetCandidateIntersection(")?;
4534                    self.write_expr(module, query, func_ctx)?;
4535                    write!(self.out, ", {tracker_expr_name})")?;
4536                }
4537            }
4538            // Not supported yet
4539            Expression::RayQueryVertexPositions { .. }
4540            | Expression::CooperativeLoad { .. }
4541            | Expression::CooperativeMultiplyAdd { .. } => {
4542                unreachable!()
4543            }
4544            // Nothing to do here, since call expression already cached
4545            Expression::CallResult(_)
4546            | Expression::AtomicResult { .. }
4547            | Expression::WorkGroupUniformLoadResult { .. }
4548            | Expression::RayQueryProceedResult
4549            | Expression::SubgroupBallotResult
4550            | Expression::SubgroupOperationResult { .. } => {}
4551        }
4552
4553        if !closing_bracket.is_empty() {
4554            write!(self.out, "{closing_bracket}")?;
4555        }
4556        Ok(())
4557    }
4558
4559    #[allow(clippy::too_many_arguments)]
4560    fn write_image_load(
4561        &mut self,
4562        module: &&Module,
4563        expr: Handle<crate::Expression>,
4564        func_ctx: &back::FunctionCtx,
4565        image: Handle<crate::Expression>,
4566        coordinate: Handle<crate::Expression>,
4567        array_index: Option<Handle<crate::Expression>>,
4568        sample: Option<Handle<crate::Expression>>,
4569        level: Option<Handle<crate::Expression>>,
4570    ) -> Result<(), Error> {
4571        let mut wrapping_type = None;
4572        match *func_ctx.resolve_type(image, &module.types) {
4573            TypeInner::Image {
4574                class: crate::ImageClass::External,
4575                ..
4576            } => {
4577                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4578                self.write_expr(module, image, func_ctx)?;
4579                write!(self.out, ", ")?;
4580                self.write_expr(module, coordinate, func_ctx)?;
4581                write!(self.out, ")")?;
4582                return Ok(());
4583            }
4584            TypeInner::Image {
4585                class: crate::ImageClass::Storage { format, .. },
4586                ..
4587            } => {
4588                if format.single_component() {
4589                    wrapping_type = Some(Scalar::from(format));
4590                }
4591            }
4592            _ => {}
4593        }
4594        if let Some(scalar) = wrapping_type {
4595            write!(
4596                self.out,
4597                "{}{}(",
4598                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4599                scalar.to_hlsl_str()?
4600            )?;
4601        }
4602        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4603        self.write_expr(module, image, func_ctx)?;
4604        write!(self.out, ".Load(")?;
4605
4606        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4607
4608        if let Some(sample) = sample {
4609            write!(self.out, ", ")?;
4610            self.write_expr(module, sample, func_ctx)?;
4611        }
4612
4613        // close bracket for Load function
4614        write!(self.out, ")")?;
4615
4616        if wrapping_type.is_some() {
4617            write!(self.out, ")")?;
4618        }
4619
4620        // return x component if return type is scalar
4621        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4622            write!(self.out, ".x")?;
4623        }
4624        Ok(())
4625    }
4626
4627    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4628    /// can be generated later.
4629    fn sampler_binding_array_info_from_expression(
4630        &mut self,
4631        module: &Module,
4632        func_ctx: &back::FunctionCtx<'_>,
4633        base: Handle<crate::Expression>,
4634        resolved: &TypeInner,
4635    ) -> Option<BindingArraySamplerInfo> {
4636        if let TypeInner::BindingArray {
4637            base: base_ty_handle,
4638            ..
4639        } = *resolved
4640        {
4641            let base_ty = &module.types[base_ty_handle].inner;
4642            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4643                let base = &func_ctx.expressions[base];
4644
4645                if let crate::Expression::GlobalVariable(handle) = *base {
4646                    let variable = &module.global_variables[handle];
4647
4648                    let sampler_heap_name = match comparison {
4649                        true => COMPARISON_SAMPLER_HEAP_VAR,
4650                        false => SAMPLER_HEAP_VAR,
4651                    };
4652
4653                    return Some(BindingArraySamplerInfo {
4654                        sampler_heap_name,
4655                        sampler_index_buffer_name: self
4656                            .wrapped
4657                            .sampler_index_buffers
4658                            .get(&super::SamplerIndexBufferKey {
4659                                group: variable.binding.unwrap().group,
4660                            })
4661                            .unwrap()
4662                            .clone(),
4663                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4664                            .clone(),
4665                    });
4666                }
4667            }
4668        }
4669
4670        None
4671    }
4672
4673    fn write_named_expr(
4674        &mut self,
4675        module: &Module,
4676        handle: Handle<crate::Expression>,
4677        name: String,
4678        // The expression which is being named.
4679        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4680        named: Handle<crate::Expression>,
4681        ctx: &back::FunctionCtx,
4682    ) -> BackendResult {
4683        match ctx.info[named].ty {
4684            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4685                TypeInner::Struct { .. } => {
4686                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4687                    write!(self.out, "{ty_name}")?;
4688                }
4689                _ => {
4690                    self.write_type(module, ty_handle)?;
4691                }
4692            },
4693            proc::TypeResolution::Value(ref inner) => {
4694                self.write_value_type(module, inner)?;
4695            }
4696        }
4697
4698        let resolved = ctx.resolve_type(named, &module.types);
4699
4700        write!(self.out, " {name}")?;
4701        // If rhs is a array type, we should write array size
4702        if let TypeInner::Array { base, size, .. } = *resolved {
4703            self.write_array_size(module, base, size)?;
4704        }
4705        write!(self.out, " = ")?;
4706        self.write_expr(module, handle, ctx)?;
4707        writeln!(self.out, ";")?;
4708        self.named_expressions.insert(named, name);
4709
4710        Ok(())
4711    }
4712
4713    /// Helper function that write default zero initialization
4714    pub(super) fn write_default_init(
4715        &mut self,
4716        module: &Module,
4717        ty: Handle<crate::Type>,
4718    ) -> BackendResult {
4719        write!(self.out, "(")?;
4720        self.write_type(module, ty)?;
4721        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4722            self.write_array_size(module, base, size)?;
4723        }
4724        write!(self.out, ")0")?;
4725        Ok(())
4726    }
4727
4728    pub(super) fn write_control_barrier(
4729        &mut self,
4730        barrier: crate::Barrier,
4731        level: back::Level,
4732    ) -> BackendResult {
4733        if barrier.contains(crate::Barrier::STORAGE) {
4734            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4735        }
4736        if barrier.contains(crate::Barrier::WORK_GROUP) {
4737            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4738        }
4739        if barrier.contains(crate::Barrier::SUB_GROUP) {
4740            // Does not exist in DirectX
4741        }
4742        if barrier.contains(crate::Barrier::TEXTURE) {
4743            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4744        }
4745        Ok(())
4746    }
4747
4748    fn write_memory_barrier(
4749        &mut self,
4750        barrier: crate::Barrier,
4751        level: back::Level,
4752    ) -> BackendResult {
4753        if barrier.contains(crate::Barrier::STORAGE) {
4754            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4755        }
4756        if barrier.contains(crate::Barrier::WORK_GROUP) {
4757            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4758        }
4759        if barrier.contains(crate::Barrier::SUB_GROUP) {
4760            // Does not exist in DirectX
4761        }
4762        if barrier.contains(crate::Barrier::TEXTURE) {
4763            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4764        }
4765        Ok(())
4766    }
4767
4768    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4769    fn emit_hlsl_atomic_tail(
4770        &mut self,
4771        module: &Module,
4772        func_ctx: &back::FunctionCtx<'_>,
4773        fun: &crate::AtomicFunction,
4774        compare_expr: Option<Handle<crate::Expression>>,
4775        value: Handle<crate::Expression>,
4776        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4777    ) -> BackendResult {
4778        if let Some(cmp) = compare_expr {
4779            write!(self.out, ", ")?;
4780            self.write_expr(module, cmp, func_ctx)?;
4781        }
4782        write!(self.out, ", ")?;
4783        if let crate::AtomicFunction::Subtract = *fun {
4784            // we just wrote `InterlockedAdd`, so negate the argument
4785            write!(self.out, "-")?;
4786        }
4787        self.write_expr(module, value, func_ctx)?;
4788        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4789            write!(self.out, ", ")?;
4790            if compare_expr.is_some() {
4791                write!(self.out, "{res_name}.old_value")?;
4792            } else {
4793                write!(self.out, "{res_name}")?;
4794            }
4795        }
4796        writeln!(self.out, ");")?;
4797        Ok(())
4798    }
4799}
4800
/// Dimensions and component width of a matrix type, as extracted from a
/// [`TypeInner::Matrix`] by the matrix-lookup helpers in this module.
pub(super) struct MatrixType {
    /// Number of columns, copied from the matrix's `columns`.
    pub(super) columns: crate::VectorSize,
    /// Number of rows, copied from the matrix's `rows`.
    pub(super) rows: crate::VectorSize,
    /// Width of each scalar component, copied from the matrix's `scalar.width`.
    pub(super) width: crate::Bytes,
}
4806
4807pub(super) fn get_inner_matrix_data(
4808    module: &Module,
4809    handle: Handle<crate::Type>,
4810) -> Option<MatrixType> {
4811    match module.types[handle].inner {
4812        TypeInner::Matrix {
4813            columns,
4814            rows,
4815            scalar,
4816        } => Some(MatrixType {
4817            columns,
4818            rows,
4819            width: scalar.width,
4820        }),
4821        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4822        _ => None,
4823    }
4824}
4825
4826/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4827/// returns a tuple of the matrix, the column (vector) index (if present), and
4828/// the row (scalar) index (if present).
4829fn find_matrix_in_access_chain(
4830    module: &Module,
4831    base: Handle<crate::Expression>,
4832    func_ctx: &back::FunctionCtx<'_>,
4833) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4834    let mut current_base = base;
4835    let mut vector = None;
4836    let mut scalar = None;
4837    loop {
4838        let resolved_tr = func_ctx
4839            .resolve_type(current_base, &module.types)
4840            .pointer_base_type();
4841        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4842
4843        match *resolved {
4844            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4845            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4846            _ => return None,
4847        }
4848
4849        let index;
4850        (current_base, index) = match func_ctx.expressions[current_base] {
4851            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4852            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4853            _ => return None,
4854        };
4855
4856        match *resolved {
4857            TypeInner::Scalar(_) => scalar = Some(index),
4858            TypeInner::Vector { .. } => vector = Some(index),
4859            _ => unreachable!(),
4860        }
4861    }
4862}
4863
4864/// Returns the matrix data if the access chain starting at `base`:
4865/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4866/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4867/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4868pub(super) fn get_inner_matrix_of_struct_array_member(
4869    module: &Module,
4870    base: Handle<crate::Expression>,
4871    func_ctx: &back::FunctionCtx<'_>,
4872    direct: bool,
4873) -> Option<MatrixType> {
4874    let mut mat_data = None;
4875    let mut array_base = None;
4876
4877    let mut current_base = base;
4878    loop {
4879        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4880        if let TypeInner::Pointer { base, .. } = *resolved {
4881            resolved = &module.types[base].inner;
4882        };
4883
4884        match *resolved {
4885            TypeInner::Matrix {
4886                columns,
4887                rows,
4888                scalar,
4889            } => {
4890                mat_data = Some(MatrixType {
4891                    columns,
4892                    rows,
4893                    width: scalar.width,
4894                })
4895            }
4896            TypeInner::Array { base, .. } => {
4897                array_base = Some(base);
4898            }
4899            TypeInner::Struct { .. } => {
4900                if let Some(array_base) = array_base {
4901                    if direct {
4902                        return mat_data;
4903                    } else {
4904                        return get_inner_matrix_data(module, array_base);
4905                    }
4906                }
4907
4908                break;
4909            }
4910            _ => break,
4911        }
4912
4913        current_base = match func_ctx.expressions[current_base] {
4914            crate::Expression::Access { base, .. } => base,
4915            crate::Expression::AccessIndex { base, .. } => base,
4916            _ => break,
4917        };
4918    }
4919    None
4920}
4921
4922/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
4923/// immediate expression, rather than traversing an access chain.
4924fn get_global_uniform_matrix(
4925    module: &Module,
4926    base: Handle<crate::Expression>,
4927    func_ctx: &back::FunctionCtx<'_>,
4928) -> Option<MatrixType> {
4929    let base_tr = func_ctx
4930        .resolve_type(base, &module.types)
4931        .pointer_base_type();
4932    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4933    match (&func_ctx.expressions[base], base_ty) {
4934        (
4935            &crate::Expression::GlobalVariable(handle),
4936            Some(&TypeInner::Matrix {
4937                columns,
4938                rows,
4939                scalar,
4940            }),
4941        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4942            Some(MatrixType {
4943                columns,
4944                rows,
4945                width: scalar.width,
4946            })
4947        }
4948        _ => None,
4949    }
4950}
4951
4952/// Returns the matrix data if the access chain starting at `base`:
4953/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
4954/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4955/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
4956fn get_inner_matrix_of_global_uniform(
4957    module: &Module,
4958    base: Handle<crate::Expression>,
4959    func_ctx: &back::FunctionCtx<'_>,
4960) -> Option<MatrixType> {
4961    let mut mat_data = None;
4962    let mut array_base = None;
4963
4964    let mut current_base = base;
4965    loop {
4966        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4967        if let TypeInner::Pointer { base, .. } = *resolved {
4968            resolved = &module.types[base].inner;
4969        };
4970
4971        match *resolved {
4972            TypeInner::Matrix {
4973                columns,
4974                rows,
4975                scalar,
4976            } => {
4977                mat_data = Some(MatrixType {
4978                    columns,
4979                    rows,
4980                    width: scalar.width,
4981                })
4982            }
4983            TypeInner::Array { base, .. } => {
4984                array_base = Some(base);
4985            }
4986            _ => break,
4987        }
4988
4989        current_base = match func_ctx.expressions[current_base] {
4990            crate::Expression::Access { base, .. } => base,
4991            crate::Expression::AccessIndex { base, .. } => base,
4992            crate::Expression::GlobalVariable(handle)
4993                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4994            {
4995                return mat_data.or_else(|| {
4996                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4997                })
4998            }
4999            _ => break,
5000        };
5001    }
5002    None
5003}