naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9    help,
10    help::{
11        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12        WrappedZeroValue,
13    },
14    storage::StoreValue,
15    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18    back::{self, get_entry_points, Baked},
19    common,
20    proc::{self, index, ExternalTextureNameKey, NameKey},
21    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Semantic prefix for user-defined locations; `write_semantic` appends the
/// location number to it (except for fragment outputs, which use `SV_Target`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type emitted by `write` for the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable holding the special constants.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Name of the `int` member of the special constants struct for the first vertex.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Name of the `int` member of the special constants struct for the first instance.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Name of the trailing `uint` member of the special constants struct.
const SPECIAL_OTHER: &str = "other";

// Identifiers shared with the rest of the HLSL backend.
// NOTE(review): none of these are used in the visible part of this file;
// presumably they name helper functions and variables emitted into the
// generated HLSL by sibling modules — confirm against their users.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix for variables in a naga statement
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
52
/// An index that is either produced by an expression or known statically.
// NOTE(review): the consumers of this type are outside the visible part of
// the file; presumably it describes one step of an access — confirm.
enum Index {
    /// The index is the value of the given expression.
    Expression(Handle<crate::Expression>),
    /// The index is a fixed number.
    Static(u32),
}
57
/// One member of a flattened entry point interface struct.
struct EpStructMember {
    /// Name of the member in the generated struct.
    name: String,
    /// Type of the member.
    ty: Handle<crate::Type>,
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    binding: Option<crate::Binding>,
    /// Position of the member in declaration order. Input members are sorted
    /// back by this value after the struct has been written in binding order.
    index: u32,
}
66
/// Information required to generate the wrapper structure that holds
/// all the flattened arguments (or results) of an entry point.
struct EntryPointBinding {
    /// Name of the fake EP argument that contains the struct
    /// with all the flattened input data.
    arg_name: String,
    /// Name of the generated structure.
    ty_name: String,
    /// Members of the generated structure.
    members: Vec<EpStructMember>,
}
78
/// The flattened input and output structs generated for one entry point.
pub(super) struct EntryPointInterface {
    /// If `Some`, the input of an entry point is gathered in a special
    /// struct whose members are written sorted by binding.
    /// The `EntryPointBinding::members` array itself is sorted by index,
    /// so that we can walk it in `write_ep_arguments_initialization`.
    input: Option<EntryPointBinding>,
    /// If `Some`, the output of an entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in the `Statement::Return` handler.
    output: Option<EntryPointBinding>,
}
90
91#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
92enum InterfaceKey {
93    Location(u32),
94    BuiltIn(crate::BuiltIn),
95    Other,
96}
97
98impl InterfaceKey {
99    const fn new(binding: Option<&crate::Binding>) -> Self {
100        match binding {
101            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
102            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
103            None => Self::Other,
104        }
105    }
106}
107
/// Direction of an entry point interface struct.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    /// The struct gathers the arguments of the entry point.
    Input,
    /// The struct gathers the results of the entry point.
    Output,
}
113
114const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
115    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
116        return false;
117    };
118    matches!(
119        builtin,
120        crate::BuiltIn::SubgroupSize
121            | crate::BuiltIn::SubgroupInvocationId
122            | crate::BuiltIn::NumSubgroups
123            | crate::BuiltIn::SubgroupId
124    )
125}
126
/// Information for how to generate a `binding_array<sampler>` access.
// NOTE(review): the code that builds and consumes this struct is outside the
// visible part of the file.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
136
137impl<'a, W: fmt::Write> super::Writer<'a, W> {
138    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
139        Self {
140            out,
141            names: crate::FastHashMap::default(),
142            namer: proc::Namer::default(),
143            options,
144            pipeline_options,
145            entry_point_io: crate::FastHashMap::default(),
146            named_expressions: crate::NamedExpressions::default(),
147            wrapped: super::Wrapped::default(),
148            written_committed_intersection: false,
149            written_candidate_intersection: false,
150            continue_ctx: back::continue_forward::ContinueCtx::default(),
151            temp_access_chain: Vec::new(),
152            need_bake_expressions: Default::default(),
153        }
154    }
155
156    fn reset(&mut self, module: &Module) {
157        self.names.clear();
158        self.namer.reset(
159            module,
160            &super::keywords::RESERVED_SET,
161            proc::KeywordSet::empty(),
162            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
163            super::keywords::RESERVED_PREFIXES,
164            &mut self.names,
165        );
166        self.entry_point_io.clear();
167        self.named_expressions.clear();
168        self.wrapped.clear();
169        self.written_committed_intersection = false;
170        self.written_candidate_intersection = false;
171        self.continue_ctx.clear();
172        self.need_bake_expressions.clear();
173    }
174
175    /// Generates statements to be inserted immediately before and at the very
176    /// start of the body of each loop, to defeat infinite loop reasoning.
177    /// The 0th item of the returned tuple should be inserted immediately prior
178    /// to the loop and the 1st item should be inserted at the very start of
179    /// the loop body.
180    ///
181    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
182    fn gen_force_bounded_loop_statements(
183        &mut self,
184        level: back::Level,
185    ) -> Option<(String, String)> {
186        if !self.options.force_loop_bounding {
187            return None;
188        }
189
190        let loop_bound_name = self.namer.call("loop_bound");
191        let max = u32::MAX;
192        // Count down from u32::MAX rather than up from 0 to avoid hang on
193        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
194        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
195        let level = level.next();
196        let break_and_inc = format!(
197            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
198{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
199        );
200
201        Some((decl, break_and_inc))
202    }
203
204    /// Helper method used to find which expressions of a given function require baking
205    ///
206    /// # Notes
207    /// Clears `need_bake_expressions` set before adding to it
208    fn update_expressions_to_bake(
209        &mut self,
210        module: &Module,
211        func: &crate::Function,
212        info: &valid::FunctionInfo,
213    ) {
214        use crate::Expression;
215        self.need_bake_expressions.clear();
216        for (exp_handle, expr) in func.expressions.iter() {
217            let expr_info = &info[exp_handle];
218            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
219            if min_ref_count <= expr_info.ref_count {
220                self.need_bake_expressions.insert(exp_handle);
221            }
222
223            if let Expression::Math { fun, arg, arg1, .. } = *expr {
224                match fun {
225                    crate::MathFunction::Asinh
226                    | crate::MathFunction::Acosh
227                    | crate::MathFunction::Atanh
228                    | crate::MathFunction::Unpack2x16float
229                    | crate::MathFunction::Unpack2x16snorm
230                    | crate::MathFunction::Unpack2x16unorm
231                    | crate::MathFunction::Unpack4x8snorm
232                    | crate::MathFunction::Unpack4x8unorm
233                    | crate::MathFunction::Unpack4xI8
234                    | crate::MathFunction::Unpack4xU8
235                    | crate::MathFunction::Pack2x16float
236                    | crate::MathFunction::Pack2x16snorm
237                    | crate::MathFunction::Pack2x16unorm
238                    | crate::MathFunction::Pack4x8snorm
239                    | crate::MathFunction::Pack4x8unorm
240                    | crate::MathFunction::Pack4xI8
241                    | crate::MathFunction::Pack4xU8
242                    | crate::MathFunction::Pack4xI8Clamp
243                    | crate::MathFunction::Pack4xU8Clamp => {
244                        self.need_bake_expressions.insert(arg);
245                    }
246                    crate::MathFunction::CountLeadingZeros => {
247                        let inner = info[exp_handle].ty.inner_with(&module.types);
248                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
249                            self.need_bake_expressions.insert(arg);
250                        }
251                    }
252                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
253                        self.need_bake_expressions.insert(arg);
254                        self.need_bake_expressions.insert(arg1.unwrap());
255                    }
256                    _ => {}
257                }
258            }
259
260            if let Expression::Derivative { axis, ctrl, expr } = *expr {
261                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
262                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
263                    self.need_bake_expressions.insert(expr);
264                }
265            }
266
267            if let Expression::GlobalVariable(_) = *expr {
268                let inner = info[exp_handle].ty.inner_with(&module.types);
269
270                if let TypeInner::Sampler { .. } = *inner {
271                    self.need_bake_expressions.insert(exp_handle);
272                }
273            }
274        }
275        for statement in func.body.iter() {
276            match *statement {
277                crate::Statement::SubgroupCollectiveOperation {
278                    op: _,
279                    collective_op: crate::CollectiveOperation::InclusiveScan,
280                    argument,
281                    result: _,
282                } => {
283                    self.need_bake_expressions.insert(argument);
284                }
285                crate::Statement::Atomic {
286                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
287                    ..
288                } => {
289                    self.need_bake_expressions.insert(cmp);
290                }
291                _ => {}
292            }
293        }
294    }
295
296    pub fn write(
297        &mut self,
298        module: &Module,
299        module_info: &valid::ModuleInfo,
300        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
301    ) -> Result<super::ReflectionInfo, Error> {
302        self.reset(module);
303
304        // Write special constants, if needed
305        if let Some(ref bt) = self.options.special_constants_binding {
306            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
307            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
308            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
309            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
310            writeln!(self.out, "}};")?;
311            write!(
312                self.out,
313                "ConstantBuffer<{}> {}: register(b{}",
314                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
315            )?;
316            if bt.space != 0 {
317                write!(self.out, ", space{}", bt.space)?;
318            }
319            writeln!(self.out, ");")?;
320
321            // Extra newline for readability
322            writeln!(self.out)?;
323        }
324
325        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
326            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
327            for i in 0..bt.size {
328                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
329            }
330            writeln!(self.out, "}};")?;
331            writeln!(
332                self.out,
333                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
334                group, group, bt.register, bt.space
335            )?;
336
337            // Extra newline for readability
338            writeln!(self.out)?;
339        }
340
341        // Save all entry point output types
342        let ep_results = module
343            .entry_points
344            .iter()
345            .map(|ep| (ep.stage, ep.function.result.clone()))
346            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
347
348        self.write_all_mat_cx2_typedefs_and_functions(module)?;
349
350        // Write all structs
351        for (handle, ty) in module.types.iter() {
352            if let TypeInner::Struct { ref members, span } = ty.inner {
353                if module.types[members.last().unwrap().ty]
354                    .inner
355                    .is_dynamically_sized(&module.types)
356                {
357                    // unsized arrays can only be in storage buffers,
358                    // for which we use `ByteAddressBuffer` anyway.
359                    continue;
360                }
361
362                let ep_result = ep_results.iter().find(|e| {
363                    if let Some(ref result) = e.1 {
364                        result.ty == handle
365                    } else {
366                        false
367                    }
368                });
369
370                self.write_struct(
371                    module,
372                    handle,
373                    members,
374                    span,
375                    ep_result.map(|r| (r.0, Io::Output)),
376                )?;
377                writeln!(self.out)?;
378            }
379        }
380
381        self.write_special_functions(module)?;
382
383        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
384        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
385
386        // Write all named constants
387        let mut constants = module
388            .constants
389            .iter()
390            .filter(|&(_, c)| c.name.is_some())
391            .peekable();
392        while let Some((handle, _)) = constants.next() {
393            self.write_global_constant(module, handle)?;
394            // Add extra newline for readability on last iteration
395            if constants.peek().is_none() {
396                writeln!(self.out)?;
397            }
398        }
399
400        // Write all globals
401        for (global, _) in module.global_variables.iter() {
402            self.write_global(module, global)?;
403        }
404
405        if !module.global_variables.is_empty() {
406            // Add extra newline for readability
407            writeln!(self.out)?;
408        }
409
410        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
411            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
412
413        // Write all entry points wrapped structs
414        for index in ep_range.clone() {
415            let ep = &module.entry_points[index];
416            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
417            let ep_io = self.write_ep_interface(
418                module,
419                &ep.function,
420                ep.stage,
421                &ep_name,
422                fragment_entry_point,
423            )?;
424            self.entry_point_io.insert(index, ep_io);
425        }
426
427        // Write all regular functions
428        for (handle, function) in module.functions.iter() {
429            let info = &module_info[handle];
430
431            // Check if all of the globals are accessible
432            if !self.options.fake_missing_bindings {
433                if let Some((var_handle, _)) =
434                    module
435                        .global_variables
436                        .iter()
437                        .find(|&(var_handle, var)| match var.binding {
438                            Some(ref binding) if !info[var_handle].is_empty() => {
439                                self.options.resolve_resource_binding(binding).is_err()
440                                    && self
441                                        .options
442                                        .resolve_external_texture_resource_binding(binding)
443                                        .is_err()
444                            }
445                            _ => false,
446                        })
447                {
448                    log::debug!(
449                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
450                        handle,
451                        function.name,
452                        var_handle
453                    );
454                    continue;
455                }
456            }
457
458            let ctx = back::FunctionCtx {
459                ty: back::FunctionType::Function(handle),
460                info,
461                expressions: &function.expressions,
462                named_expressions: &function.named_expressions,
463            };
464            let name = self.names[&NameKey::Function(handle)].clone();
465
466            self.write_wrapped_functions(module, &ctx)?;
467
468            self.write_function(module, name.as_str(), function, &ctx, info)?;
469
470            writeln!(self.out)?;
471        }
472
473        let mut translated_ep_names = Vec::with_capacity(ep_range.len());
474
475        // Write all entry points
476        for index in ep_range {
477            let ep = &module.entry_points[index];
478            let info = module_info.get_entry_point(index);
479
480            if !self.options.fake_missing_bindings {
481                let mut ep_error = None;
482                for (var_handle, var) in module.global_variables.iter() {
483                    match var.binding {
484                        Some(ref binding) if !info[var_handle].is_empty() => {
485                            if let Err(err) = self.options.resolve_resource_binding(binding) {
486                                if self
487                                    .options
488                                    .resolve_external_texture_resource_binding(binding)
489                                    .is_err()
490                                {
491                                    ep_error = Some(err);
492                                    break;
493                                }
494                            }
495                        }
496                        _ => {}
497                    }
498                }
499                if let Some(err) = ep_error {
500                    translated_ep_names.push(Err(err));
501                    continue;
502                }
503            }
504
505            let ctx = back::FunctionCtx {
506                ty: back::FunctionType::EntryPoint(index as u16),
507                info,
508                expressions: &ep.function.expressions,
509                named_expressions: &ep.function.named_expressions,
510            };
511
512            self.write_wrapped_functions(module, &ctx)?;
513
514            if ep.stage.compute_like() {
515                // HLSL is calling workgroup size "num threads"
516                let num_threads = ep.workgroup_size;
517                writeln!(
518                    self.out,
519                    "[numthreads({}, {}, {})]",
520                    num_threads[0], num_threads[1], num_threads[2]
521                )?;
522            }
523
524            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
525            self.write_function(module, &name, &ep.function, &ctx, info)?;
526
527            if index < module.entry_points.len() - 1 {
528                writeln!(self.out)?;
529            }
530
531            translated_ep_names.push(Ok(name));
532        }
533
534        Ok(super::ReflectionInfo {
535            entry_point_names: translated_ep_names,
536        })
537    }
538
539    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
540        match *binding {
541            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
542                write!(self.out, "precise ")?;
543            }
544            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
545                write!(self.out, "noperspective ")?;
546            }
547            crate::Binding::Location {
548                interpolation,
549                sampling,
550                ..
551            } => {
552                if let Some(interpolation) = interpolation {
553                    if let Some(string) = interpolation.to_hlsl_str() {
554                        write!(self.out, "{string} ")?
555                    }
556                }
557
558                if let Some(sampling) = sampling {
559                    if let Some(string) = sampling.to_hlsl_str() {
560                        write!(self.out, "{string} ")?
561                    }
562                }
563            }
564            crate::Binding::BuiltIn(_) => {}
565        }
566
567        Ok(())
568    }
569
570    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
571    // if they are struct, so that the `stage` argument here could be omitted.
572    fn write_semantic(
573        &mut self,
574        binding: &Option<crate::Binding>,
575        stage: Option<(ShaderStage, Io)>,
576    ) -> BackendResult {
577        match *binding {
578            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
579                if builtin == crate::BuiltIn::ViewIndex
580                    && self.options.shader_model < ShaderModel::V6_1
581                {
582                    return Err(Error::ShaderModelTooLow(
583                        "used @builtin(view_index) or SV_ViewID".to_string(),
584                        ShaderModel::V6_1,
585                    ));
586                }
587                let builtin_str = builtin.to_hlsl_str()?;
588                write!(self.out, " : {builtin_str}")?;
589            }
590            Some(crate::Binding::Location {
591                blend_src: Some(1), ..
592            }) => {
593                write!(self.out, " : SV_Target1")?;
594            }
595            Some(crate::Binding::Location { location, .. }) => {
596                if stage == Some((ShaderStage::Fragment, Io::Output)) {
597                    write!(self.out, " : SV_Target{location}")?;
598                } else {
599                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
600                }
601            }
602            _ => {}
603        }
604
605        Ok(())
606    }
607
608    fn write_interface_struct(
609        &mut self,
610        module: &Module,
611        shader_stage: (ShaderStage, Io),
612        struct_name: String,
613        mut members: Vec<EpStructMember>,
614    ) -> Result<EntryPointBinding, Error> {
615        // Sort the members so that first come the user-defined varyings
616        // in ascending locations, and then built-ins. This allows VS and FS
617        // interfaces to match with regards to order.
618        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
619
620        write!(self.out, "struct {struct_name}")?;
621        writeln!(self.out, " {{")?;
622        for m in members.iter() {
623            // Sanity check that each IO member is a built-in or is assigned a
624            // location. Also see note about nesting in `write_ep_input_struct`.
625            debug_assert!(m.binding.is_some());
626
627            if is_subgroup_builtin_binding(&m.binding) {
628                continue;
629            }
630            write!(self.out, "{}", back::INDENT)?;
631            if let Some(ref binding) = m.binding {
632                self.write_modifier(binding)?;
633            }
634            self.write_type(module, m.ty)?;
635            write!(self.out, " {}", &m.name)?;
636            self.write_semantic(&m.binding, Some(shader_stage))?;
637            writeln!(self.out, ";")?;
638        }
639        if members.iter().any(|arg| {
640            matches!(
641                arg.binding,
642                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
643            )
644        }) {
645            writeln!(
646                self.out,
647                "{}uint __local_invocation_index : SV_GroupIndex;",
648                back::INDENT
649            )?;
650        }
651        writeln!(self.out, "}};")?;
652        writeln!(self.out)?;
653
654        // See ordering notes on EntryPointInterface fields
655        match shader_stage.1 {
656            Io::Input => {
657                // bring back the original order
658                members.sort_by_key(|m| m.index);
659            }
660            Io::Output => {
661                // keep it sorted by binding
662            }
663        }
664
665        Ok(EntryPointBinding {
666            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
667            ty_name: struct_name,
668            members,
669        })
670    }
671
672    /// Flatten all entry point arguments into a single struct.
673    /// This is needed since we need to re-order them: first placing user locations,
674    /// then built-ins.
675    fn write_ep_input_struct(
676        &mut self,
677        module: &Module,
678        func: &crate::Function,
679        stage: ShaderStage,
680        entry_point_name: &str,
681    ) -> Result<EntryPointBinding, Error> {
682        let struct_name = format!("{stage:?}Input_{entry_point_name}");
683
684        let mut fake_members = Vec::new();
685        for arg in func.arguments.iter() {
686            // NOTE: We don't need to handle nesting structs. All members must
687            // be either built-ins or assigned a location. I.E. `binding` is
688            // `Some`. This is checked in `VaryingContext::validate`. See:
689            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
690            match module.types[arg.ty].inner {
691                TypeInner::Struct { ref members, .. } => {
692                    for member in members.iter() {
693                        let name = self.namer.call_or(&member.name, "member");
694                        let index = fake_members.len() as u32;
695                        fake_members.push(EpStructMember {
696                            name,
697                            ty: member.ty,
698                            binding: member.binding.clone(),
699                            index,
700                        });
701                    }
702                }
703                _ => {
704                    let member_name = self.namer.call_or(&arg.name, "member");
705                    let index = fake_members.len() as u32;
706                    fake_members.push(EpStructMember {
707                        name: member_name,
708                        ty: arg.ty,
709                        binding: arg.binding.clone(),
710                        index,
711                    });
712                }
713            }
714        }
715
716        self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
717    }
718
719    /// Flatten all entry point results into a single struct.
720    /// This is needed since we need to re-order them: first placing user locations,
721    /// then built-ins.
722    fn write_ep_output_struct(
723        &mut self,
724        module: &Module,
725        result: &crate::FunctionResult,
726        stage: ShaderStage,
727        entry_point_name: &str,
728        frag_ep: Option<&FragmentEntryPoint<'_>>,
729    ) -> Result<EntryPointBinding, Error> {
730        let struct_name = format!("{stage:?}Output_{entry_point_name}");
731
732        let empty = [];
733        let members = match module.types[result.ty].inner {
734            TypeInner::Struct { ref members, .. } => members,
735            ref other => {
736                log::error!("Unexpected {other:?} output type without a binding");
737                &empty[..]
738            }
739        };
740
741        // Gather list of fragment input locations. We use this below to remove user-defined
742        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
743        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
744        // writer is supplied with information about the fragment entry point.
745        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
746            let mut fs_input_locs = Vec::new();
747            for arg in frag_ep.func.arguments.iter() {
748                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
749                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
750                    Some(crate::Binding::BuiltIn(_)) | None => {}
751                };
752
753                // NOTE: We don't need to handle struct nesting. See note in
754                // `write_ep_input_struct`.
755                match frag_ep.module.types[arg.ty].inner {
756                    TypeInner::Struct { ref members, .. } => {
757                        for member in members.iter() {
758                            push_if_location(&member.binding);
759                        }
760                    }
761                    _ => push_if_location(&arg.binding),
762                }
763            }
764            fs_input_locs.sort();
765            Some(fs_input_locs)
766        } else {
767            None
768        };
769
770        let mut fake_members = Vec::new();
771        for (index, member) in members.iter().enumerate() {
772            if let Some(ref fs_input_locs) = fs_input_locs {
773                match member.binding {
774                    Some(crate::Binding::Location { location, .. }) => {
775                        if fs_input_locs.binary_search(&location).is_err() {
776                            continue;
777                        }
778                    }
779                    Some(crate::Binding::BuiltIn(_)) | None => {}
780                }
781            }
782
783            let member_name = self.namer.call_or(&member.name, "member");
784            fake_members.push(EpStructMember {
785                name: member_name,
786                ty: member.ty,
787                binding: member.binding.clone(),
788                index: index as u32,
789            });
790        }
791
792        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
793    }
794
795    /// Writes special interface structures for an entry point. The special structures have
796    /// all the fields flattened into them and sorted by binding. They are needed to emulate
797    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
798    fn write_ep_interface(
799        &mut self,
800        module: &Module,
801        func: &crate::Function,
802        stage: ShaderStage,
803        ep_name: &str,
804        frag_ep: Option<&FragmentEntryPoint<'_>>,
805    ) -> Result<EntryPointInterface, Error> {
806        Ok(EntryPointInterface {
807            input: if !func.arguments.is_empty()
808                && (stage == ShaderStage::Fragment
809                    || func
810                        .arguments
811                        .iter()
812                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
813            {
814                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
815            } else {
816                None
817            },
818            output: match func.result {
819                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
820                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
821                }
822                _ => None,
823            },
824        })
825    }
826
827    fn write_ep_argument_initialization(
828        &mut self,
829        ep: &crate::EntryPoint,
830        ep_input: &EntryPointBinding,
831        fake_member: &EpStructMember,
832    ) -> BackendResult {
833        match fake_member.binding {
834            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
835                write!(self.out, "WaveGetLaneCount()")?
836            }
837            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
838                write!(self.out, "WaveGetLaneIndex()")?
839            }
840            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
841                self.out,
842                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
843                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
844            )?,
845            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
846                write!(
847                    self.out,
848                    "{}.__local_invocation_index / WaveGetLaneCount()",
849                    ep_input.arg_name
850                )?;
851            }
852            _ => {
853                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
854            }
855        }
856        Ok(())
857    }
858
859    /// Write an entry point preface that initializes the arguments as specified in IR.
860    fn write_ep_arguments_initialization(
861        &mut self,
862        module: &Module,
863        func: &crate::Function,
864        ep_index: u16,
865    ) -> BackendResult {
866        let ep = &module.entry_points[ep_index as usize];
867        let ep_input = match self
868            .entry_point_io
869            .get_mut(&(ep_index as usize))
870            .unwrap()
871            .input
872            .take()
873        {
874            Some(ep_input) => ep_input,
875            None => return Ok(()),
876        };
877        let mut fake_iter = ep_input.members.iter();
878        for (arg_index, arg) in func.arguments.iter().enumerate() {
879            write!(self.out, "{}", back::INDENT)?;
880            self.write_type(module, arg.ty)?;
881            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
882            write!(self.out, " {arg_name}")?;
883            match module.types[arg.ty].inner {
884                TypeInner::Array { base, size, .. } => {
885                    self.write_array_size(module, base, size)?;
886                    write!(self.out, " = ")?;
887                    self.write_ep_argument_initialization(
888                        ep,
889                        &ep_input,
890                        fake_iter.next().unwrap(),
891                    )?;
892                    writeln!(self.out, ";")?;
893                }
894                TypeInner::Struct { ref members, .. } => {
895                    write!(self.out, " = {{ ")?;
896                    for index in 0..members.len() {
897                        if index != 0 {
898                            write!(self.out, ", ")?;
899                        }
900                        self.write_ep_argument_initialization(
901                            ep,
902                            &ep_input,
903                            fake_iter.next().unwrap(),
904                        )?;
905                    }
906                    writeln!(self.out, " }};")?;
907                }
908                _ => {
909                    write!(self.out, " = ")?;
910                    self.write_ep_argument_initialization(
911                        ep,
912                        &ep_input,
913                        fake_iter.next().unwrap(),
914                    )?;
915                    writeln!(self.out, ";")?;
916                }
917            }
918        }
919        assert!(fake_iter.next().is_none());
920        Ok(())
921    }
922
    /// Helper method used to write global variables
    ///
    /// External textures and samplers are delegated to
    /// `write_global_external_texture` and `write_global_sampler` respectively.
    /// A global whose resource binding cannot be resolved is skipped: a debug
    /// message is logged, nothing is written, and `Ok(())` is returned.
    ///
    /// # Notes
    /// Always adds a newline when a declaration is written by this method itself
    ///
    /// # Errors
    /// Returns `Error::Unimplemented` for an immediate-data global whose type is
    /// not a struct, and propagates any error from the writer helpers it calls.
    ///
    /// # Panics
    /// Panics on the `Function` address space (`unreachable!`), on the
    /// `TaskPayload`, `RayPayload` and `IncomingRayPayload` address spaces
    /// (`unimplemented!`), and when an immediate-data global is written while
    /// `options.immediates_target` is `None`.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify the global by its element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are handled entirely differently, so defer entirely to that method.
        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
        // their bindings separately.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        // Skip globals whose binding has no resolvable bind target.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // Samplers are handled entirely differently, so defer entirely to that method.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the leading part of the declaration (storage qualifier and/or type)
        // and pick the register letter used later in `register(<letter><index>)`.
        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::TaskPayload => unimplemented!(),
            crate::AddressSpace::Uniform => {
                // constant buffer declarations are expected to be inlined, e.g.
                // `cbuffer foo: register(b0) { field1: type1; }`
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                // Writable storage buffers use the `RW` prefix and a `u` register;
                // read-only ones use a `t` register.
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // all storage textures are UAV, unconditionally
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::Immediate => {
                // The type of the immediates will be wrapped in `ConstantBuffer`
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unimplemented!()
            }
        };

        // If the global is immediate data, write the type now because it will be a
        // generic argument to `ConstantBuffer`
        if global.space == crate::AddressSpace::Immediate {
            self.write_global_type(module, global.ty)?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Immediates need to be assigned a binding explicitly by the consumer
        // since naga has no way to know the binding from the shader alone
        if global.space == crate::AddressSpace::Immediate {
            // Only struct-typed immediates are supported here.
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .immediates_target
                .as_ref()
                .expect("No bind target was defined for the immediates block");
            write!(self.out, ": register(b{}", target.register)?;
            // The register space is only written when it is non-zero.
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // This cannot fail: the same binding was resolved successfully near the
            // top of this function (an error there returns early).
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // need to write the binding array size if the type was emitted with `write_type`
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                // A size supplied by the bind target takes precedence over the IR size.
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Private globals always get an initializer: the explicit one if
            // present, otherwise a default value.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        if global.space == crate::AddressSpace::Uniform {
            // Inline body of the `cbuffer`: a single field with the global's type and name.
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1120
1121    fn write_global_sampler(
1122        &mut self,
1123        module: &Module,
1124        handle: Handle<crate::GlobalVariable>,
1125        global: &crate::GlobalVariable,
1126    ) -> BackendResult {
1127        let binding = *global.binding.as_ref().unwrap();
1128
1129        let key = super::SamplerIndexBufferKey {
1130            group: binding.group,
1131        };
1132        self.write_wrapped_sampler_buffer(key)?;
1133
1134        // This was already validated, so we can confidently unwrap it.
1135        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1136
1137        match module.types[global.ty].inner {
1138            TypeInner::Sampler { comparison } => {
1139                // If we are generating a static access, we create a variable for the sampler.
1140                //
1141                // This prevents the DXIL from containing multiple lookups for the sampler, which
1142                // the backend compiler will then have to eliminate. AMD does seem to be able to
1143                // eliminate these, but better safe than sorry.
1144
1145                write!(self.out, "static const ")?;
1146                self.write_type(module, global.ty)?;
1147
1148                let heap_var = if comparison {
1149                    COMPARISON_SAMPLER_HEAP_VAR
1150                } else {
1151                    SAMPLER_HEAP_VAR
1152                };
1153
1154                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1155                let name = &self.names[&NameKey::GlobalVariable(handle)];
1156                writeln!(
1157                    self.out,
1158                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1159                    register = bt.register
1160                )?;
1161            }
1162            TypeInner::BindingArray { .. } => {
1163                // If we are generating a binding array, we cannot directly access the sampler as the index
1164                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1165                // that represents the "base" index into the sampler index buffer. This constant is added
1166                // to the user provided index to get the final index into the sampler index buffer.
1167
1168                let name = &self.names[&NameKey::GlobalVariable(handle)];
1169                writeln!(
1170                    self.out,
1171                    "static const uint {name} = {register};",
1172                    register = bt.register
1173                )?;
1174            }
1175            _ => unreachable!(),
1176        };
1177
1178        Ok(())
1179    }
1180
1181    /// Write the declarations for an external texture global variable.
1182    /// These are emitted as multiple global variables: Three `Texture2D`s
1183    /// (one for each plane) and a parameters cbuffer.
1184    fn write_global_external_texture(
1185        &mut self,
1186        module: &Module,
1187        handle: Handle<crate::GlobalVariable>,
1188        global: &crate::GlobalVariable,
1189    ) -> BackendResult {
1190        let res_binding = global
1191            .binding
1192            .as_ref()
1193            .expect("External texture global variables must have a resource binding");
1194        let ext_tex_bindings = match self
1195            .options
1196            .resolve_external_texture_resource_binding(res_binding)
1197        {
1198            Ok(bindings) => bindings,
1199            Err(err) => {
1200                log::debug!(
1201                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1202                    handle,
1203                    global.name,
1204                    err,
1205                );
1206                return Ok(());
1207            }
1208        };
1209
1210        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1211            write!(
1212                self.out,
1213                "Texture2D<float4> {}: register(t{}",
1214                name, bt.register
1215            )?;
1216            if bt.space != 0 {
1217                write!(self.out, ", space{}", bt.space)?;
1218            }
1219            writeln!(self.out, ");")?;
1220            Ok(())
1221        };
1222        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1223            let plane_name = &self.names
1224                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1225            write_plane(bt, plane_name)?;
1226        }
1227
1228        let params_name = &self.names
1229            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1230        let params_ty_name =
1231            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1232        write!(
1233            self.out,
1234            "cbuffer {}: register(b{}",
1235            params_name, ext_tex_bindings.params.register
1236        )?;
1237        if ext_tex_bindings.params.space != 0 {
1238            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1239        }
1240        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1241
1242        Ok(())
1243    }
1244
1245    /// Helper method used to write global constants
1246    ///
1247    /// # Notes
1248    /// Ends in a newline
1249    fn write_global_constant(
1250        &mut self,
1251        module: &Module,
1252        handle: Handle<crate::Constant>,
1253    ) -> BackendResult {
1254        write!(self.out, "static const ")?;
1255        let constant = &module.constants[handle];
1256        self.write_type(module, constant.ty)?;
1257        let name = &self.names[&NameKey::Constant(handle)];
1258        write!(self.out, " {name}")?;
1259        // Write size for array type
1260        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1261            self.write_array_size(module, base, size)?;
1262        }
1263        write!(self.out, " = ")?;
1264        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1265        writeln!(self.out, ";")?;
1266        Ok(())
1267    }
1268
1269    pub(super) fn write_array_size(
1270        &mut self,
1271        module: &Module,
1272        base: Handle<crate::Type>,
1273        size: crate::ArraySize,
1274    ) -> BackendResult {
1275        write!(self.out, "[")?;
1276
1277        match size.resolve(module.to_ctx())? {
1278            proc::IndexableLength::Known(size) => {
1279                write!(self.out, "{size}")?;
1280            }
1281            proc::IndexableLength::Dynamic => unreachable!(),
1282        }
1283
1284        write!(self.out, "]")?;
1285
1286        if let TypeInner::Array {
1287            base: next_base,
1288            size: next_size,
1289            ..
1290        } = module.types[base].inner
1291        {
1292            self.write_array_size(module, next_base, next_size)?;
1293        }
1294
1295        Ok(())
1296    }
1297
1298    /// Helper method used to write structs
1299    ///
1300    /// # Notes
1301    /// Ends in a newline
1302    fn write_struct(
1303        &mut self,
1304        module: &Module,
1305        handle: Handle<crate::Type>,
1306        members: &[crate::StructMember],
1307        span: u32,
1308        shader_stage: Option<(ShaderStage, Io)>,
1309    ) -> BackendResult {
1310        // Write struct name
1311        let struct_name = &self.names[&NameKey::Type(handle)];
1312        writeln!(self.out, "struct {struct_name} {{")?;
1313
1314        let mut last_offset = 0;
1315        for (index, member) in members.iter().enumerate() {
1316            if member.binding.is_none() && member.offset > last_offset {
1317                // using int as padding should work as long as the backend
1318                // doesn't support a type that's less than 4 bytes in size
1319                // (Error::UnsupportedScalar catches this)
1320                let padding = (member.offset - last_offset) / 4;
1321                for i in 0..padding {
1322                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1323                }
1324            }
1325            let ty_inner = &module.types[member.ty].inner;
1326            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1327
1328            // The indentation is only for readability
1329            write!(self.out, "{}", back::INDENT)?;
1330
1331            match module.types[member.ty].inner {
1332                TypeInner::Array { base, size, .. } => {
1333                    // HLSL arrays are written as `type name[size]`
1334
1335                    self.write_global_type(module, member.ty)?;
1336
1337                    // Write `name`
1338                    write!(
1339                        self.out,
1340                        " {}",
1341                        &self.names[&NameKey::StructMember(handle, index as u32)]
1342                    )?;
1343                    // Write [size]
1344                    self.write_array_size(module, base, size)?;
1345                }
1346                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1347                // See the module-level block comment in mod.rs for details.
1348                TypeInner::Matrix {
1349                    rows,
1350                    columns,
1351                    scalar,
1352                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1353                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1354                    let field_name_key = NameKey::StructMember(handle, index as u32);
1355
1356                    for i in 0..columns as u8 {
1357                        if i != 0 {
1358                            write!(self.out, "; ")?;
1359                        }
1360                        self.write_value_type(module, &vec_ty)?;
1361                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1362                    }
1363                }
1364                _ => {
1365                    // Write modifier before type
1366                    if let Some(ref binding) = member.binding {
1367                        self.write_modifier(binding)?;
1368                    }
1369
1370                    // Even though Naga IR matrices are column-major, we must describe
1371                    // matrices passed from the CPU as being in row-major order.
1372                    // See the module-level block comment in mod.rs for details.
1373                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1374                        write!(self.out, "row_major ")?;
1375                    }
1376
1377                    // Write the member type and name
1378                    self.write_type(module, member.ty)?;
1379                    write!(
1380                        self.out,
1381                        " {}",
1382                        &self.names[&NameKey::StructMember(handle, index as u32)]
1383                    )?;
1384                }
1385            }
1386
1387            self.write_semantic(&member.binding, shader_stage)?;
1388            writeln!(self.out, ";")?;
1389        }
1390
1391        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1392        if members.last().unwrap().binding.is_none() && span > last_offset {
1393            let padding = (span - last_offset) / 4;
1394            for i in 0..padding {
1395                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1396            }
1397        }
1398
1399        writeln!(self.out, "}};")?;
1400        Ok(())
1401    }
1402
1403    /// Helper method used to write global/structs non image/sampler types
1404    ///
1405    /// # Notes
1406    /// Adds no trailing or leading whitespace
1407    pub(super) fn write_global_type(
1408        &mut self,
1409        module: &Module,
1410        ty: Handle<crate::Type>,
1411    ) -> BackendResult {
1412        let matrix_data = get_inner_matrix_data(module, ty);
1413
1414        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1415        // See the module-level block comment in mod.rs for details.
1416        if let Some(MatrixType {
1417            columns,
1418            rows: crate::VectorSize::Bi,
1419            width: 4,
1420        }) = matrix_data
1421        {
1422            write!(self.out, "__mat{}x2", columns as u8)?;
1423        } else {
1424            // Even though Naga IR matrices are column-major, we must describe
1425            // matrices passed from the CPU as being in row-major order.
1426            // See the module-level block comment in mod.rs for details.
1427            if matrix_data.is_some() {
1428                write!(self.out, "row_major ")?;
1429            }
1430
1431            self.write_type(module, ty)?;
1432        }
1433
1434        Ok(())
1435    }
1436
1437    /// Helper method used to write non image/sampler types
1438    ///
1439    /// # Notes
1440    /// Adds no trailing or leading whitespace
1441    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1442        let inner = &module.types[ty].inner;
1443        match *inner {
1444            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1445            // hlsl array has the size separated from the base type
1446            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1447                self.write_type(module, base)?
1448            }
1449            ref other => self.write_value_type(module, other)?,
1450        }
1451
1452        Ok(())
1453    }
1454
1455    /// Helper method used to write value types
1456    ///
1457    /// # Notes
1458    /// Adds no trailing or leading whitespace
1459    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1460        match *inner {
1461            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1462                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1463            }
1464            TypeInner::Vector { size, scalar } => {
1465                write!(
1466                    self.out,
1467                    "{}{}",
1468                    scalar.to_hlsl_str()?,
1469                    common::vector_size_str(size)
1470                )?;
1471            }
1472            TypeInner::Matrix {
1473                columns,
1474                rows,
1475                scalar,
1476            } => {
1477                // The IR supports only float matrix
1478                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1479
1480                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1481                write!(
1482                    self.out,
1483                    "{}{}x{}",
1484                    scalar.to_hlsl_str()?,
1485                    common::vector_size_str(columns),
1486                    common::vector_size_str(rows),
1487                )?;
1488            }
1489            TypeInner::Image {
1490                dim,
1491                arrayed,
1492                class,
1493            } => {
1494                self.write_image_type(dim, arrayed, class)?;
1495            }
1496            TypeInner::Sampler { comparison } => {
1497                let sampler = if comparison {
1498                    "SamplerComparisonState"
1499                } else {
1500                    "SamplerState"
1501                };
1502                write!(self.out, "{sampler}")?;
1503            }
1504            // HLSL arrays are written as `type name[size]`
1505            // Current code is written arrays only as `[size]`
1506            // Base `type` and `name` should be written outside
1507            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1508                self.write_array_size(module, base, size)?;
1509            }
1510            TypeInner::AccelerationStructure { .. } => {
1511                write!(self.out, "RaytracingAccelerationStructure")?;
1512            }
1513            TypeInner::RayQuery { .. } => {
1514                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1515                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1516            }
1517            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1518        }
1519
1520        Ok(())
1521    }
1522
1523    /// Helper method used to write functions
1524    /// # Notes
1525    /// Ends in a newline
1526    fn write_function(
1527        &mut self,
1528        module: &Module,
1529        name: &str,
1530        func: &crate::Function,
1531        func_ctx: &back::FunctionCtx<'_>,
1532        info: &valid::FunctionInfo,
1533    ) -> BackendResult {
1534        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1535
1536        self.update_expressions_to_bake(module, func, info);
1537
1538        if let Some(ref result) = func.result {
1539            // Write typedef if return type is an array
1540            let array_return_type = match module.types[result.ty].inner {
1541                TypeInner::Array { base, size, .. } => {
1542                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1543                    write!(self.out, "typedef ")?;
1544                    self.write_type(module, result.ty)?;
1545                    write!(self.out, " {array_return_type}")?;
1546                    self.write_array_size(module, base, size)?;
1547                    writeln!(self.out, ";")?;
1548                    Some(array_return_type)
1549                }
1550                _ => None,
1551            };
1552
1553            // Write modifier
1554            if let Some(
1555                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1556            ) = result.binding
1557            {
1558                self.write_modifier(binding)?;
1559            }
1560
1561            // Write return type
1562            match func_ctx.ty {
1563                back::FunctionType::Function(_) => {
1564                    if let Some(array_return_type) = array_return_type {
1565                        write!(self.out, "{array_return_type}")?;
1566                    } else {
1567                        self.write_type(module, result.ty)?;
1568                    }
1569                }
1570                back::FunctionType::EntryPoint(index) => {
1571                    if let Some(ref ep_output) =
1572                        self.entry_point_io.get(&(index as usize)).unwrap().output
1573                    {
1574                        write!(self.out, "{}", ep_output.ty_name)?;
1575                    } else {
1576                        self.write_type(module, result.ty)?;
1577                    }
1578                }
1579            }
1580        } else {
1581            write!(self.out, "void")?;
1582        }
1583
1584        // Write function name
1585        write!(self.out, " {name}(")?;
1586
1587        let need_workgroup_variables_initialization =
1588            self.need_workgroup_variables_initialization(func_ctx, module);
1589
1590        // Write function arguments for non entry point functions
1591        match func_ctx.ty {
1592            back::FunctionType::Function(handle) => {
1593                for (index, arg) in func.arguments.iter().enumerate() {
1594                    if index != 0 {
1595                        write!(self.out, ", ")?;
1596                    }
1597
1598                    self.write_function_argument(module, handle, arg, index)?;
1599                }
1600            }
1601            back::FunctionType::EntryPoint(ep_index) => {
1602                if let Some(ref ep_input) =
1603                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1604                {
1605                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1606                } else {
1607                    let stage = module.entry_points[ep_index as usize].stage;
1608                    for (index, arg) in func.arguments.iter().enumerate() {
1609                        if index != 0 {
1610                            write!(self.out, ", ")?;
1611                        }
1612                        self.write_type(module, arg.ty)?;
1613
1614                        let argument_name =
1615                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1616
1617                        write!(self.out, " {argument_name}")?;
1618                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1619                            self.write_array_size(module, base, size)?;
1620                        }
1621
1622                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1623                    }
1624                }
1625                if need_workgroup_variables_initialization {
1626                    if self
1627                        .entry_point_io
1628                        .get(&(ep_index as usize))
1629                        .unwrap()
1630                        .input
1631                        .is_some()
1632                        || !func.arguments.is_empty()
1633                    {
1634                        write!(self.out, ", ")?;
1635                    }
1636                    write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
1637                }
1638            }
1639        }
1640        // Ends of arguments
1641        write!(self.out, ")")?;
1642
1643        // Write semantic if it present
1644        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1645            let stage = module.entry_points[index as usize].stage;
1646            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1647                self.write_semantic(binding, Some((stage, Io::Output)))?;
1648            }
1649        }
1650
1651        // Function body start
1652        writeln!(self.out)?;
1653        writeln!(self.out, "{{")?;
1654
1655        if need_workgroup_variables_initialization {
1656            self.write_workgroup_variables_initialization(func_ctx, module)?;
1657        }
1658
1659        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1660            self.write_ep_arguments_initialization(module, func, index)?;
1661        }
1662
1663        // Write function local variables
1664        for (handle, local) in func.local_variables.iter() {
1665            // Write indentation (only for readability)
1666            write!(self.out, "{}", back::INDENT)?;
1667
1668            // Write the local name
1669            // The leading space is important
1670            self.write_type(module, local.ty)?;
1671            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1672            // Write size for array type
1673            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1674                self.write_array_size(module, base, size)?;
1675            }
1676
1677            let is_ray_query = match module.types[local.ty].inner {
1678                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1679                TypeInner::RayQuery { .. } => true,
1680                _ => {
1681                    write!(self.out, " = ")?;
1682                    // Write the local initializer if needed
1683                    if let Some(init) = local.init {
1684                        self.write_expr(module, init, func_ctx)?;
1685                    } else {
1686                        // Zero initialize local variables
1687                        self.write_default_init(module, local.ty)?;
1688                    }
1689                    false
1690                }
1691            };
1692            // Finish the local with `;` and add a newline (only for readability)
1693            writeln!(self.out, ";")?;
1694            // If it's a ray query, we also want a tracker variable
1695            if is_ray_query {
1696                write!(self.out, "{}", back::INDENT)?;
1697                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1698                writeln!(
1699                    self.out,
1700                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1701                    self.names[&func_ctx.name_key(handle)]
1702                )?;
1703            }
1704        }
1705
1706        if !func.local_variables.is_empty() {
1707            writeln!(self.out)?;
1708        }
1709
1710        // Write the function body (statement list)
1711        for sta in func.body.iter() {
1712            // The indentation should always be 1 when writing the function body
1713            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1714        }
1715
1716        writeln!(self.out, "}}")?;
1717
1718        self.named_expressions.clear();
1719
1720        Ok(())
1721    }
1722
1723    fn write_function_argument(
1724        &mut self,
1725        module: &Module,
1726        handle: Handle<crate::Function>,
1727        arg: &crate::FunctionArgument,
1728        index: usize,
1729    ) -> BackendResult {
1730        // External texture arguments must be expanded into separate
1731        // arguments for each plane and the params buffer.
1732        if let TypeInner::Image {
1733            class: crate::ImageClass::External,
1734            ..
1735        } = module.types[arg.ty].inner
1736        {
1737            return self.write_function_external_texture_argument(module, handle, index);
1738        }
1739
1740        // Write argument type
1741        let arg_ty = match module.types[arg.ty].inner {
1742            // pointers in function arguments are expected and resolve to `inout`
1743            TypeInner::Pointer { base, .. } => {
1744                //TODO: can we narrow this down to just `in` when possible?
1745                write!(self.out, "inout ")?;
1746                base
1747            }
1748            _ => arg.ty,
1749        };
1750        self.write_type(module, arg_ty)?;
1751
1752        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1753
1754        // Write argument name. Space is important.
1755        write!(self.out, " {argument_name}")?;
1756        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1757            self.write_array_size(module, base, size)?;
1758        }
1759
1760        Ok(())
1761    }
1762
1763    fn write_function_external_texture_argument(
1764        &mut self,
1765        module: &Module,
1766        handle: Handle<crate::Function>,
1767        index: usize,
1768    ) -> BackendResult {
1769        let plane_names = [0, 1, 2].map(|i| {
1770            &self.names[&NameKey::ExternalTextureFunctionArgument(
1771                handle,
1772                index as u32,
1773                ExternalTextureNameKey::Plane(i),
1774            )]
1775        });
1776        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1777            handle,
1778            index as u32,
1779            ExternalTextureNameKey::Params,
1780        )];
1781        let params_ty_name =
1782            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1783        write!(
1784            self.out,
1785            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1786            plane_names[0], plane_names[1], plane_names[2],
1787        )?;
1788        Ok(())
1789    }
1790
1791    fn need_workgroup_variables_initialization(
1792        &mut self,
1793        func_ctx: &back::FunctionCtx,
1794        module: &Module,
1795    ) -> bool {
1796        self.options.zero_initialize_workgroup_memory
1797            && func_ctx.ty.is_compute_like_entry_point(module)
1798            && module.global_variables.iter().any(|(handle, var)| {
1799                !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1800            })
1801    }
1802
1803    fn write_workgroup_variables_initialization(
1804        &mut self,
1805        func_ctx: &back::FunctionCtx,
1806        module: &Module,
1807    ) -> BackendResult {
1808        let level = back::Level(1);
1809
1810        writeln!(
1811            self.out,
1812            "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1813        )?;
1814
1815        let vars = module.global_variables.iter().filter(|&(handle, var)| {
1816            !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1817        });
1818
1819        for (handle, var) in vars {
1820            let name = &self.names[&NameKey::GlobalVariable(handle)];
1821            write!(self.out, "{}{} = ", level.next(), name)?;
1822            self.write_default_init(module, var.ty)?;
1823            writeln!(self.out, ";")?;
1824        }
1825
1826        writeln!(self.out, "{level}}}")?;
1827        self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1828    }
1829
1830    /// Helper method used to write switches
1831    fn write_switch(
1832        &mut self,
1833        module: &Module,
1834        func_ctx: &back::FunctionCtx<'_>,
1835        level: back::Level,
1836        selector: Handle<crate::Expression>,
1837        cases: &[crate::SwitchCase],
1838    ) -> BackendResult {
1839        // Write all cases
1840        let indent_level_1 = level.next();
1841        let indent_level_2 = indent_level_1.next();
1842
1843        // See docs of `back::continue_forward` module.
1844        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
1845            writeln!(self.out, "{level}bool {variable} = false;",)?;
1846        };
1847
1848        // Check if there is only one body, by seeing if all except the last case are fall through
1849        // with empty bodies. FXC doesn't handle these switches correctly, so
1850        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
1851        // is no need to check if one of the cases would have matched.
1852        let one_body = cases
1853            .iter()
1854            .rev()
1855            .skip(1)
1856            .all(|case| case.fall_through && case.body.is_empty());
1857        if one_body {
1858            // Start the do-while
1859            writeln!(self.out, "{level}do {{")?;
1860            // Note: Expressions have no side-effects so we don't need to emit selector expression.
1861
1862            // Body
1863            if let Some(case) = cases.last() {
1864                for sta in case.body.iter() {
1865                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
1866                }
1867            }
1868            // End do-while
1869            writeln!(self.out, "{level}}} while(false);")?;
1870        } else {
1871            // Start the switch
1872            write!(self.out, "{level}")?;
1873            write!(self.out, "switch(")?;
1874            self.write_expr(module, selector, func_ctx)?;
1875            writeln!(self.out, ") {{")?;
1876
1877            for (i, case) in cases.iter().enumerate() {
1878                match case.value {
1879                    crate::SwitchValue::I32(value) => {
1880                        write!(self.out, "{indent_level_1}case {value}:")?
1881                    }
1882                    crate::SwitchValue::U32(value) => {
1883                        write!(self.out, "{indent_level_1}case {value}u:")?
1884                    }
1885                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
1886                }
1887
1888                // The new block is not only stylistic, it plays a role here:
1889                // We might end up having to write the same case body
1890                // multiple times due to FXC not supporting fallthrough.
1891                // Therefore, some `Expression`s written by `Statement::Emit`
1892                // will end up having the same name (`_expr<handle_index>`).
1893                // So we need to put each case in its own scope.
1894                let write_block_braces = !(case.fall_through && case.body.is_empty());
1895                if write_block_braces {
1896                    writeln!(self.out, " {{")?;
1897                } else {
1898                    writeln!(self.out)?;
1899                }
1900
1901                // Although FXC does support a series of case clauses before
1902                // a block[^yes], it does not support fallthrough from a
1903                // non-empty case block to the next[^no]. If this case has a
1904                // non-empty body with a fallthrough, emulate that by
1905                // duplicating the bodies of all the cases it would fall
1906                // into as extensions of this case's own body. This makes
1907                // the HLSL output potentially quadratic in the size of the
1908                // Naga IR.
1909                //
1910                // [^yes]: ```hlsl
1911                // case 1:
1912                // case 2: do_stuff()
1913                // ```
1914                // [^no]: ```hlsl
1915                // case 1: do_this();
1916                // case 2: do_that();
1917                // ```
1918                if case.fall_through && !case.body.is_empty() {
1919                    let curr_len = i + 1;
1920                    let end_case_idx = curr_len
1921                        + cases
1922                            .iter()
1923                            .skip(curr_len)
1924                            .position(|case| !case.fall_through)
1925                            .unwrap();
1926                    let indent_level_3 = indent_level_2.next();
1927                    for case in &cases[i..=end_case_idx] {
1928                        writeln!(self.out, "{indent_level_2}{{")?;
1929                        let prev_len = self.named_expressions.len();
1930                        for sta in case.body.iter() {
1931                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
1932                        }
1933                        // Clear all named expressions that were previously inserted by the statements in the block
1934                        self.named_expressions.truncate(prev_len);
1935                        writeln!(self.out, "{indent_level_2}}}")?;
1936                    }
1937
1938                    let last_case = &cases[end_case_idx];
1939                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
1940                        writeln!(self.out, "{indent_level_2}break;")?;
1941                    }
1942                } else {
1943                    for sta in case.body.iter() {
1944                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
1945                    }
1946                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
1947                        writeln!(self.out, "{indent_level_2}break;")?;
1948                    }
1949                }
1950
1951                if write_block_braces {
1952                    writeln!(self.out, "{indent_level_1}}}")?;
1953                }
1954            }
1955
1956            writeln!(self.out, "{level}}}")?;
1957        }
1958
1959        // Handle any forwarded continue statements.
1960        use back::continue_forward::ExitControlFlow;
1961        let op = match self.continue_ctx.exit_switch() {
1962            ExitControlFlow::None => None,
1963            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
1964            ExitControlFlow::Break { variable } => Some(("break", variable)),
1965        };
1966        if let Some((control_flow, variable)) = op {
1967            writeln!(self.out, "{level}if ({variable}) {{")?;
1968            writeln!(self.out, "{indent_level_1}{control_flow};")?;
1969            writeln!(self.out, "{level}}}")?;
1970        }
1971
1972        Ok(())
1973    }
1974
1975    fn write_index(
1976        &mut self,
1977        module: &Module,
1978        index: Index,
1979        func_ctx: &back::FunctionCtx<'_>,
1980    ) -> BackendResult {
1981        match index {
1982            Index::Static(index) => {
1983                write!(self.out, "{index}")?;
1984            }
1985            Index::Expression(index) => {
1986                self.write_expr(module, index, func_ctx)?;
1987            }
1988        }
1989        Ok(())
1990    }
1991
1992    /// Helper method used to write statements
1993    ///
1994    /// # Notes
1995    /// Always adds a newline
1996    fn write_stmt(
1997        &mut self,
1998        module: &Module,
1999        stmt: &crate::Statement,
2000        func_ctx: &back::FunctionCtx<'_>,
2001        level: back::Level,
2002    ) -> BackendResult {
2003        use crate::Statement;
2004
2005        match *stmt {
2006            Statement::Emit(ref range) => {
2007                for handle in range.clone() {
2008                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2009                    let expr_name = if ptr_class.is_some() {
2010                        // HLSL can't save a pointer-valued expression in a variable,
2011                        // but we shouldn't ever need to: they should never be named expressions,
2012                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
2013                        None
2014                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2015                        // Front end provides names for all variables at the start of writing.
2016                        // But we write them to step by step. We need to recache them
2017                        // Otherwise, we could accidentally write variable name instead of full expression.
2018                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
2019                        Some(self.namer.call(name))
2020                    } else if self.need_bake_expressions.contains(&handle) {
2021                        Some(Baked(handle).to_string())
2022                    } else {
2023                        None
2024                    };
2025
2026                    if let Some(name) = expr_name {
2027                        write!(self.out, "{level}")?;
2028                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2029                    }
2030                }
2031            }
2032            // TODO: copy-paste from glsl-out
2033            Statement::Block(ref block) => {
2034                write!(self.out, "{level}")?;
2035                writeln!(self.out, "{{")?;
2036                for sta in block.iter() {
2037                    // Increase the indentation to help with readability
2038                    self.write_stmt(module, sta, func_ctx, level.next())?
2039                }
2040                writeln!(self.out, "{level}}}")?
2041            }
2042            // TODO: copy-paste from glsl-out
2043            Statement::If {
2044                condition,
2045                ref accept,
2046                ref reject,
2047            } => {
2048                write!(self.out, "{level}")?;
2049                write!(self.out, "if (")?;
2050                self.write_expr(module, condition, func_ctx)?;
2051                writeln!(self.out, ") {{")?;
2052
2053                let l2 = level.next();
2054                for sta in accept {
2055                    // Increase indentation to help with readability
2056                    self.write_stmt(module, sta, func_ctx, l2)?;
2057                }
2058
2059                // If there are no statements in the reject block we skip writing it
2060                // This is only for readability
2061                if !reject.is_empty() {
2062                    writeln!(self.out, "{level}}} else {{")?;
2063
2064                    for sta in reject {
2065                        // Increase indentation to help with readability
2066                        self.write_stmt(module, sta, func_ctx, l2)?;
2067                    }
2068                }
2069
2070                writeln!(self.out, "{level}}}")?
2071            }
2072            // TODO: copy-paste from glsl-out
2073            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2074            Statement::Return { value: None } => {
2075                writeln!(self.out, "{level}return;")?;
2076            }
2077            Statement::Return { value: Some(expr) } => {
2078                let base_ty_res = &func_ctx.info[expr].ty;
2079                let mut resolved = base_ty_res.inner_with(&module.types);
2080                if let TypeInner::Pointer { base, space: _ } = *resolved {
2081                    resolved = &module.types[base].inner;
2082                }
2083
2084                if let TypeInner::Struct { .. } = *resolved {
2085                    // We can safely unwrap here, since we now we working with struct
2086                    let ty = base_ty_res.handle().unwrap();
2087                    let struct_name = &self.names[&NameKey::Type(ty)];
2088                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2089                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2090                    self.write_expr(module, expr, func_ctx)?;
2091                    writeln!(self.out, ";")?;
2092
2093                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2094                    let ep_output = match func_ctx.ty {
2095                        back::FunctionType::Function(_) => None,
2096                        back::FunctionType::EntryPoint(index) => self
2097                            .entry_point_io
2098                            .get(&(index as usize))
2099                            .unwrap()
2100                            .output
2101                            .as_ref(),
2102                    };
2103                    let final_name = match ep_output {
2104                        Some(ep_output) => {
2105                            let final_name = self.namer.call(&variable_name);
2106                            write!(
2107                                self.out,
2108                                "{}const {} {} = {{ ",
2109                                level, ep_output.ty_name, final_name,
2110                            )?;
2111                            for (index, m) in ep_output.members.iter().enumerate() {
2112                                if index != 0 {
2113                                    write!(self.out, ", ")?;
2114                                }
2115                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2116                                write!(self.out, "{variable_name}.{member_name}")?;
2117                            }
2118                            writeln!(self.out, " }};")?;
2119                            final_name
2120                        }
2121                        None => variable_name,
2122                    };
2123                    writeln!(self.out, "{level}return {final_name};")?;
2124                } else {
2125                    write!(self.out, "{level}return ")?;
2126                    self.write_expr(module, expr, func_ctx)?;
2127                    writeln!(self.out, ";")?
2128                }
2129            }
2130            Statement::Store { pointer, value } => {
2131                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2132                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2133                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2134                    self.write_storage_store(
2135                        module,
2136                        var_handle,
2137                        StoreValue::Expression(value),
2138                        func_ctx,
2139                        level,
2140                        None,
2141                    )?;
2142                } else {
2143                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2144                    // See the module-level block comment in mod.rs for details.
2145                    //
2146                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2147                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2148                    enum MatrixAccess {
2149                        Direct {
2150                            base: Handle<crate::Expression>,
2151                            index: u32,
2152                        },
2153                        Struct {
2154                            columns: crate::VectorSize,
2155                            base: Handle<crate::Expression>,
2156                        },
2157                    }
2158
2159                    let get_members = |expr: Handle<crate::Expression>| {
2160                        let resolved = func_ctx.resolve_type(expr, &module.types);
2161                        match *resolved {
2162                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2163                                TypeInner::Struct { ref members, .. } => Some(members),
2164                                _ => None,
2165                            },
2166                            _ => None,
2167                        }
2168                    };
2169
2170                    write!(self.out, "{level}")?;
2171
2172                    let matrix_access_on_lhs =
2173                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2174                            |(matrix_expr, vector, scalar)| match (
2175                                func_ctx.resolve_type(matrix_expr, &module.types),
2176                                &func_ctx.expressions[matrix_expr],
2177                            ) {
2178                                (
2179                                    &TypeInner::Pointer { base: ty, .. },
2180                                    &crate::Expression::AccessIndex { base, index },
2181                                ) if matches!(
2182                                    module.types[ty].inner,
2183                                    TypeInner::Matrix {
2184                                        rows: crate::VectorSize::Bi,
2185                                        ..
2186                                    }
2187                                ) && get_members(base)
2188                                    .map(|members| members[index as usize].binding.is_none())
2189                                    == Some(true) =>
2190                                {
2191                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2192                                }
2193                                _ => {
2194                                    if let Some(MatrixType {
2195                                        columns,
2196                                        rows: crate::VectorSize::Bi,
2197                                        width: 4,
2198                                    }) = get_inner_matrix_of_struct_array_member(
2199                                        module,
2200                                        matrix_expr,
2201                                        func_ctx,
2202                                        true,
2203                                    ) {
2204                                        Some((
2205                                            MatrixAccess::Struct {
2206                                                columns,
2207                                                base: matrix_expr,
2208                                            },
2209                                            vector,
2210                                            scalar,
2211                                        ))
2212                                    } else {
2213                                        None
2214                                    }
2215                                }
2216                            },
2217                        );
2218
2219                    match matrix_access_on_lhs {
2220                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2221                            let base_ty_res = &func_ctx.info[base].ty;
2222                            let resolved = base_ty_res.inner_with(&module.types);
2223                            let ty = match *resolved {
2224                                TypeInner::Pointer { base, .. } => base,
2225                                _ => base_ty_res.handle().unwrap(),
2226                            };
2227
2228                            if let Some(Index::Static(vec_index)) = vector {
2229                                self.write_expr(module, base, func_ctx)?;
2230                                write!(
2231                                    self.out,
2232                                    ".{}_{}",
2233                                    &self.names[&NameKey::StructMember(ty, index)],
2234                                    vec_index
2235                                )?;
2236
2237                                if let Some(scalar_index) = scalar {
2238                                    write!(self.out, "[")?;
2239                                    self.write_index(module, scalar_index, func_ctx)?;
2240                                    write!(self.out, "]")?;
2241                                }
2242
2243                                write!(self.out, " = ")?;
2244                                self.write_expr(module, value, func_ctx)?;
2245                                writeln!(self.out, ";")?;
2246                            } else {
2247                                let access = WrappedStructMatrixAccess { ty, index };
2248                                match (&vector, &scalar) {
2249                                    (&Some(_), &Some(_)) => {
2250                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2251                                            access,
2252                                        )?;
2253                                    }
2254                                    (&Some(_), &None) => {
2255                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2256                                            access,
2257                                        )?;
2258                                    }
2259                                    (&None, _) => {
2260                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2261                                    }
2262                                }
2263
2264                                write!(self.out, "(")?;
2265                                self.write_expr(module, base, func_ctx)?;
2266                                write!(self.out, ", ")?;
2267                                self.write_expr(module, value, func_ctx)?;
2268
2269                                if let Some(Index::Expression(vec_index)) = vector {
2270                                    write!(self.out, ", ")?;
2271                                    self.write_expr(module, vec_index, func_ctx)?;
2272
2273                                    if let Some(scalar_index) = scalar {
2274                                        write!(self.out, ", ")?;
2275                                        self.write_index(module, scalar_index, func_ctx)?;
2276                                    }
2277                                }
2278                                writeln!(self.out, ");")?;
2279                            }
2280                        }
2281                        Some((
2282                            MatrixAccess::Struct { columns, base },
2283                            Some(Index::Expression(vec_index)),
2284                            scalar,
2285                        )) => {
2286                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2287                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2288
2289                            if scalar.is_some() {
2290                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2291                            } else {
2292                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2293                            }
2294                            write!(self.out, "(")?;
2295                            self.write_expr(module, base, func_ctx)?;
2296                            write!(self.out, ", ")?;
2297                            self.write_expr(module, vec_index, func_ctx)?;
2298
2299                            if let Some(scalar_index) = scalar {
2300                                write!(self.out, ", ")?;
2301                                self.write_index(module, scalar_index, func_ctx)?;
2302                            }
2303
2304                            write!(self.out, ", ")?;
2305                            self.write_expr(module, value, func_ctx)?;
2306
2307                            writeln!(self.out, ");")?;
2308                        }
2309                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2310                        | Some((MatrixAccess::Struct { .. }, None, _))
2311                        | None => {
2312                            self.write_expr(module, pointer, func_ctx)?;
2313                            write!(self.out, " = ")?;
2314
2315                            // We cast the RHS of this store in cases where the LHS
2316                            // is a struct member with type:
2317                            //  - matCx2 or
2318                            //  - a (possibly nested) array of matCx2's
2319                            if let Some(MatrixType {
2320                                columns,
2321                                rows: crate::VectorSize::Bi,
2322                                width: 4,
2323                            }) = get_inner_matrix_of_struct_array_member(
2324                                module, pointer, func_ctx, false,
2325                            ) {
2326                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2327                                if let TypeInner::Pointer { base, .. } = *resolved {
2328                                    resolved = &module.types[base].inner;
2329                                }
2330
2331                                write!(self.out, "(__mat{}x2", columns as u8)?;
2332                                if let TypeInner::Array { base, size, .. } = *resolved {
2333                                    self.write_array_size(module, base, size)?;
2334                                }
2335                                write!(self.out, ")")?;
2336                            }
2337
2338                            self.write_expr(module, value, func_ctx)?;
2339                            writeln!(self.out, ";")?
2340                        }
2341                    }
2342                }
2343            }
2344            Statement::Loop {
2345                ref body,
2346                ref continuing,
2347                break_if,
2348            } => {
2349                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2350                let gate_name = (!continuing.is_empty() || break_if.is_some())
2351                    .then(|| self.namer.call("loop_init"));
2352
2353                if let Some((ref decl, _)) = force_loop_bound_statements {
2354                    writeln!(self.out, "{decl}")?;
2355                }
2356                if let Some(ref gate_name) = gate_name {
2357                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2358                }
2359
2360                self.continue_ctx.enter_loop();
2361                writeln!(self.out, "{level}while(true) {{")?;
2362                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2363                    writeln!(self.out, "{break_and_inc}")?;
2364                }
2365                let l2 = level.next();
2366                if let Some(gate_name) = gate_name {
2367                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2368                    let l3 = l2.next();
2369                    for sta in continuing.iter() {
2370                        self.write_stmt(module, sta, func_ctx, l3)?;
2371                    }
2372                    if let Some(condition) = break_if {
2373                        write!(self.out, "{l3}if (")?;
2374                        self.write_expr(module, condition, func_ctx)?;
2375                        writeln!(self.out, ") {{")?;
2376                        writeln!(self.out, "{}break;", l3.next())?;
2377                        writeln!(self.out, "{l3}}}")?;
2378                    }
2379                    writeln!(self.out, "{l2}}}")?;
2380                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2381                }
2382
2383                for sta in body.iter() {
2384                    self.write_stmt(module, sta, func_ctx, l2)?;
2385                }
2386
2387                writeln!(self.out, "{level}}}")?;
2388                self.continue_ctx.exit_loop();
2389            }
2390            Statement::Break => writeln!(self.out, "{level}break;")?,
2391            Statement::Continue => {
2392                if let Some(variable) = self.continue_ctx.continue_encountered() {
2393                    writeln!(self.out, "{level}{variable} = true;")?;
2394                    writeln!(self.out, "{level}break;")?
2395                } else {
2396                    writeln!(self.out, "{level}continue;")?
2397                }
2398            }
2399            Statement::ControlBarrier(barrier) => {
2400                self.write_control_barrier(barrier, level)?;
2401            }
2402            Statement::MemoryBarrier(barrier) => {
2403                self.write_memory_barrier(barrier, level)?;
2404            }
2405            Statement::ImageStore {
2406                image,
2407                coordinate,
2408                array_index,
2409                value,
2410            } => {
2411                write!(self.out, "{level}")?;
2412                self.write_expr(module, image, func_ctx)?;
2413
2414                write!(self.out, "[")?;
2415                if let Some(index) = array_index {
                    // Array index accepted only for texture_storage_2d_array, so we can safely use int3(coordinate, array_index) here
2417                    write!(self.out, "int3(")?;
2418                    self.write_expr(module, coordinate, func_ctx)?;
2419                    write!(self.out, ", ")?;
2420                    self.write_expr(module, index, func_ctx)?;
2421                    write!(self.out, ")")?;
2422                } else {
2423                    self.write_expr(module, coordinate, func_ctx)?;
2424                }
2425                write!(self.out, "]")?;
2426
2427                write!(self.out, " = ")?;
2428                self.write_expr(module, value, func_ctx)?;
2429                writeln!(self.out, ";")?;
2430            }
2431            Statement::Call {
2432                function,
2433                ref arguments,
2434                result,
2435            } => {
2436                write!(self.out, "{level}")?;
2437                if let Some(expr) = result {
2438                    write!(self.out, "const ")?;
2439                    let name = Baked(expr).to_string();
2440                    let expr_ty = &func_ctx.info[expr].ty;
2441                    let ty_inner = match *expr_ty {
2442                        proc::TypeResolution::Handle(handle) => {
2443                            self.write_type(module, handle)?;
2444                            &module.types[handle].inner
2445                        }
2446                        proc::TypeResolution::Value(ref value) => {
2447                            self.write_value_type(module, value)?;
2448                            value
2449                        }
2450                    };
2451                    write!(self.out, " {name}")?;
2452                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2453                        self.write_array_size(module, base, size)?;
2454                    }
2455                    write!(self.out, " = ")?;
2456                    self.named_expressions.insert(expr, name);
2457                }
2458                let func_name = &self.names[&NameKey::Function(function)];
2459                write!(self.out, "{func_name}(")?;
2460                for (index, argument) in arguments.iter().enumerate() {
2461                    if index != 0 {
2462                        write!(self.out, ", ")?;
2463                    }
2464                    self.write_expr(module, *argument, func_ctx)?;
2465                }
2466                writeln!(self.out, ");")?
2467            }
2468            Statement::Atomic {
2469                pointer,
2470                ref fun,
2471                value,
2472                result,
2473            } => {
2474                write!(self.out, "{level}")?;
2475                let res_var_info = if let Some(res_handle) = result {
2476                    let name = Baked(res_handle).to_string();
2477                    match func_ctx.info[res_handle].ty {
2478                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2479                        proc::TypeResolution::Value(ref value) => {
2480                            self.write_value_type(module, value)?
2481                        }
2482                    };
2483                    write!(self.out, " {name}; ")?;
2484                    self.named_expressions.insert(res_handle, name.clone());
2485                    Some((res_handle, name))
2486                } else {
2487                    None
2488                };
2489                let pointer_space = func_ctx
2490                    .resolve_type(pointer, &module.types)
2491                    .pointer_space()
2492                    .unwrap();
2493                let fun_str = fun.to_hlsl_suffix();
2494                let compare_expr = match *fun {
2495                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2496                    _ => None,
2497                };
2498                match pointer_space {
2499                    crate::AddressSpace::WorkGroup => {
2500                        write!(self.out, "Interlocked{fun_str}(")?;
2501                        self.write_expr(module, pointer, func_ctx)?;
2502                        self.emit_hlsl_atomic_tail(
2503                            module,
2504                            func_ctx,
2505                            fun,
2506                            compare_expr,
2507                            value,
2508                            &res_var_info,
2509                        )?;
2510                    }
2511                    crate::AddressSpace::Storage { .. } => {
2512                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2513                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2514                        let width = match func_ctx.resolve_type(value, &module.types) {
2515                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2516                            _ => "",
2517                        };
2518                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2519                        let chain = mem::take(&mut self.temp_access_chain);
2520                        self.write_storage_address(module, &chain, func_ctx)?;
2521                        self.temp_access_chain = chain;
2522                        self.emit_hlsl_atomic_tail(
2523                            module,
2524                            func_ctx,
2525                            fun,
2526                            compare_expr,
2527                            value,
2528                            &res_var_info,
2529                        )?;
2530                    }
2531                    ref other => {
2532                        return Err(Error::Custom(format!(
2533                            "invalid address space {other:?} for atomic statement"
2534                        )))
2535                    }
2536                }
2537                if let Some(cmp) = compare_expr {
2538                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2539                        write!(
2540                            self.out,
2541                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2542                        )?;
2543                        self.write_expr(module, cmp, func_ctx)?;
2544                        writeln!(self.out, ");")?;
2545                    }
2546                }
2547            }
2548            Statement::ImageAtomic {
2549                image,
2550                coordinate,
2551                array_index,
2552                fun,
2553                value,
2554            } => {
2555                write!(self.out, "{level}")?;
2556
2557                let fun_str = fun.to_hlsl_suffix();
2558                write!(self.out, "Interlocked{fun_str}(")?;
2559                self.write_expr(module, image, func_ctx)?;
2560                write!(self.out, "[")?;
2561                self.write_texture_coordinates(
2562                    "int",
2563                    coordinate,
2564                    array_index,
2565                    None,
2566                    module,
2567                    func_ctx,
2568                )?;
2569                write!(self.out, "],")?;
2570
2571                self.write_expr(module, value, func_ctx)?;
2572                writeln!(self.out, ");")?;
2573            }
2574            Statement::WorkGroupUniformLoad { pointer, result } => {
2575                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2576                write!(self.out, "{level}")?;
2577                let name = Baked(result).to_string();
2578                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2579
2580                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2581            }
2582            Statement::Switch {
2583                selector,
2584                ref cases,
2585            } => {
2586                self.write_switch(module, func_ctx, level, selector, cases)?;
2587            }
2588            Statement::RayQuery { query, ref fun } => {
2589                // There are three possibilities for a ptr to be:
2590                // 1. A variable
2591                // 2. A function argument
2592                // 3. part of a struct
2593                //
2594                // 2 and 3 are not possible, a ray query (in naga IR)
2595                // is not allowed to be passed into a function, and
2596                // all languages disallow it in a struct (you get fun results if
2597                // you try it :) ).
2598                //
2599                // Therefore, the ray query expression must be a variable.
2600                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2601                else {
2602                    unreachable!()
2603                };
2604
2605                let tracker_expr_name = format!(
2606                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2607                    self.names[&func_ctx.name_key(query_var)]
2608                );
2609
2610                match *fun {
2611                    RayQueryFunction::Initialize {
2612                        acceleration_structure,
2613                        descriptor,
2614                    } => {
2615                        self.write_initialize_function(
2616                            module,
2617                            level,
2618                            query,
2619                            acceleration_structure,
2620                            descriptor,
2621                            &tracker_expr_name,
2622                            func_ctx,
2623                        )?;
2624                    }
2625                    RayQueryFunction::Proceed { result } => {
2626                        self.write_proceed(
2627                            module,
2628                            level,
2629                            query,
2630                            result,
2631                            &tracker_expr_name,
2632                            func_ctx,
2633                        )?;
2634                    }
2635                    RayQueryFunction::GenerateIntersection { hit_t } => {
2636                        self.write_generate_intersection(
2637                            module,
2638                            level,
2639                            query,
2640                            hit_t,
2641                            &tracker_expr_name,
2642                            func_ctx,
2643                        )?;
2644                    }
2645                    RayQueryFunction::ConfirmIntersection => {
2646                        self.write_confirm_intersection(
2647                            module,
2648                            level,
2649                            query,
2650                            &tracker_expr_name,
2651                            func_ctx,
2652                        )?;
2653                    }
2654                    RayQueryFunction::Terminate => {
2655                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2656                    }
2657                }
2658            }
2659            Statement::SubgroupBallot { result, predicate } => {
2660                write!(self.out, "{level}")?;
2661                let name = Baked(result).to_string();
2662                write!(self.out, "const uint4 {name} = ")?;
2663                self.named_expressions.insert(result, name);
2664
2665                write!(self.out, "WaveActiveBallot(")?;
2666                match predicate {
2667                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2668                    None => write!(self.out, "true")?,
2669                }
2670                writeln!(self.out, ");")?;
2671            }
2672            Statement::SubgroupCollectiveOperation {
2673                op,
2674                collective_op,
2675                argument,
2676                result,
2677            } => {
2678                write!(self.out, "{level}")?;
2679                write!(self.out, "const ")?;
2680                let name = Baked(result).to_string();
2681                match func_ctx.info[result].ty {
2682                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2683                    proc::TypeResolution::Value(ref value) => {
2684                        self.write_value_type(module, value)?
2685                    }
2686                };
2687                write!(self.out, " {name} = ")?;
2688                self.named_expressions.insert(result, name);
2689
2690                match (collective_op, op) {
2691                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2692                        write!(self.out, "WaveActiveAllTrue(")?
2693                    }
2694                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2695                        write!(self.out, "WaveActiveAnyTrue(")?
2696                    }
2697                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2698                        write!(self.out, "WaveActiveSum(")?
2699                    }
2700                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2701                        write!(self.out, "WaveActiveProduct(")?
2702                    }
2703                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2704                        write!(self.out, "WaveActiveMax(")?
2705                    }
2706                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2707                        write!(self.out, "WaveActiveMin(")?
2708                    }
2709                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2710                        write!(self.out, "WaveActiveBitAnd(")?
2711                    }
2712                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2713                        write!(self.out, "WaveActiveBitOr(")?
2714                    }
2715                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2716                        write!(self.out, "WaveActiveBitXor(")?
2717                    }
2718                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2719                        write!(self.out, "WavePrefixSum(")?
2720                    }
2721                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2722                        write!(self.out, "WavePrefixProduct(")?
2723                    }
2724                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2725                        self.write_expr(module, argument, func_ctx)?;
2726                        write!(self.out, " + WavePrefixSum(")?;
2727                    }
2728                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2729                        self.write_expr(module, argument, func_ctx)?;
2730                        write!(self.out, " * WavePrefixProduct(")?;
2731                    }
2732                    _ => unimplemented!(),
2733                }
2734                self.write_expr(module, argument, func_ctx)?;
2735                writeln!(self.out, ");")?;
2736            }
2737            Statement::SubgroupGather {
2738                mode,
2739                argument,
2740                result,
2741            } => {
2742                write!(self.out, "{level}")?;
2743                write!(self.out, "const ")?;
2744                let name = Baked(result).to_string();
2745                match func_ctx.info[result].ty {
2746                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2747                    proc::TypeResolution::Value(ref value) => {
2748                        self.write_value_type(module, value)?
2749                    }
2750                };
2751                write!(self.out, " {name} = ")?;
2752                self.named_expressions.insert(result, name);
2753                match mode {
2754                    crate::GatherMode::BroadcastFirst => {
2755                        write!(self.out, "WaveReadLaneFirst(")?;
2756                        self.write_expr(module, argument, func_ctx)?;
2757                    }
2758                    crate::GatherMode::QuadBroadcast(index) => {
2759                        write!(self.out, "QuadReadLaneAt(")?;
2760                        self.write_expr(module, argument, func_ctx)?;
2761                        write!(self.out, ", ")?;
2762                        self.write_expr(module, index, func_ctx)?;
2763                    }
2764                    crate::GatherMode::QuadSwap(direction) => {
2765                        match direction {
2766                            crate::Direction::X => {
2767                                write!(self.out, "QuadReadAcrossX(")?;
2768                            }
2769                            crate::Direction::Y => {
2770                                write!(self.out, "QuadReadAcrossY(")?;
2771                            }
2772                            crate::Direction::Diagonal => {
2773                                write!(self.out, "QuadReadAcrossDiagonal(")?;
2774                            }
2775                        }
2776                        self.write_expr(module, argument, func_ctx)?;
2777                    }
2778                    _ => {
2779                        write!(self.out, "WaveReadLaneAt(")?;
2780                        self.write_expr(module, argument, func_ctx)?;
2781                        write!(self.out, ", ")?;
2782                        match mode {
2783                            crate::GatherMode::BroadcastFirst => unreachable!(),
2784                            crate::GatherMode::Broadcast(index)
2785                            | crate::GatherMode::Shuffle(index) => {
2786                                self.write_expr(module, index, func_ctx)?;
2787                            }
2788                            crate::GatherMode::ShuffleDown(index) => {
2789                                write!(self.out, "WaveGetLaneIndex() + ")?;
2790                                self.write_expr(module, index, func_ctx)?;
2791                            }
2792                            crate::GatherMode::ShuffleUp(index) => {
2793                                write!(self.out, "WaveGetLaneIndex() - ")?;
2794                                self.write_expr(module, index, func_ctx)?;
2795                            }
2796                            crate::GatherMode::ShuffleXor(index) => {
2797                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
2798                                self.write_expr(module, index, func_ctx)?;
2799                            }
2800                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2801                            crate::GatherMode::QuadSwap(_) => unreachable!(),
2802                        }
2803                    }
2804                }
2805                writeln!(self.out, ");")?;
2806            }
2807            Statement::CooperativeStore { .. } => unimplemented!(),
2808            Statement::RayPipelineFunction(_) => unreachable!(),
2809        }
2810
2811        Ok(())
2812    }
2813
2814    fn write_const_expression(
2815        &mut self,
2816        module: &Module,
2817        expr: Handle<crate::Expression>,
2818        arena: &crate::Arena<crate::Expression>,
2819    ) -> BackendResult {
2820        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2821            writer.write_const_expression(module, expr, arena)
2822        })
2823    }
2824
2825    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2826        match literal {
2827            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2828            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2829            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2830            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2831            // `-2147483648` is parsed by some compilers as unary negation of
2832            // positive 2147483648, which is too large for an int, causing
2833            // issues for some compilers. Neither DXC nor FXC appear to have
2834            // this problem, but this is not specified and could change. We
2835            // therefore use `-2147483647 - 1` as a precaution.
2836            crate::Literal::I32(value) if value == i32::MIN => {
2837                write!(self.out, "int({} - 1)", value + 1)?
2838            }
2839            // HLSL has no suffix for explicit i32 literals, but not using any suffix
2840            // makes the type ambiguous which prevents overload resolution from
2841            // working. So we explicitly use the int() constructor syntax.
2842            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2843            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2844            // I64 version of the minimum I32 value issue described above.
2845            crate::Literal::I64(value) if value == i64::MIN => {
2846                write!(self.out, "({}L - 1L)", value + 1)?;
2847            }
2848            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2849            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2850            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2851                return Err(Error::Custom(
2852                    "Abstract types should not appear in IR presented to backends".into(),
2853                ));
2854            }
2855        }
2856        Ok(())
2857    }
2858
2859    fn write_possibly_const_expression<E>(
2860        &mut self,
2861        module: &Module,
2862        expr: Handle<crate::Expression>,
2863        expressions: &crate::Arena<crate::Expression>,
2864        write_expression: E,
2865    ) -> BackendResult
2866    where
2867        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2868    {
2869        use crate::Expression;
2870
2871        match expressions[expr] {
2872            Expression::Literal(literal) => {
2873                self.write_literal(literal)?;
2874            }
2875            Expression::Constant(handle) => {
2876                let constant = &module.constants[handle];
2877                if constant.name.is_some() {
2878                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2879                } else {
2880                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
2881                }
2882            }
2883            Expression::ZeroValue(ty) => {
2884                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2885                write!(self.out, "()")?;
2886            }
2887            Expression::Compose { ty, ref components } => {
2888                match module.types[ty].inner {
2889                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2890                        self.write_wrapped_constructor_function_name(
2891                            module,
2892                            WrappedConstructor { ty },
2893                        )?;
2894                    }
2895                    _ => {
2896                        self.write_type(module, ty)?;
2897                    }
2898                };
2899                write!(self.out, "(")?;
2900                for (index, component) in components.iter().enumerate() {
2901                    if index != 0 {
2902                        write!(self.out, ", ")?;
2903                    }
2904                    write_expression(self, *component)?;
2905                }
2906                write!(self.out, ")")?;
2907            }
2908            Expression::Splat { size, value } => {
2909                // hlsl is not supported one value constructor
2910                // if we write, for example, int4(0), dxc returns error:
2911                // error: too few elements in vector initialization (expected 4 elements, have 1)
2912                let number_of_components = match size {
2913                    crate::VectorSize::Bi => "xx",
2914                    crate::VectorSize::Tri => "xxx",
2915                    crate::VectorSize::Quad => "xxxx",
2916                };
2917                write!(self.out, "(")?;
2918                write_expression(self, value)?;
2919                write!(self.out, ").{number_of_components}")?
2920            }
2921            _ => {
2922                return Err(Error::Override);
2923            }
2924        }
2925
2926        Ok(())
2927    }
2928
2929    /// Helper method to write expressions
2930    ///
2931    /// # Notes
2932    /// Doesn't add any newlines or leading/trailing spaces
2933    pub(super) fn write_expr(
2934        &mut self,
2935        module: &Module,
2936        expr: Handle<crate::Expression>,
2937        func_ctx: &back::FunctionCtx<'_>,
2938    ) -> BackendResult {
2939        use crate::Expression;
2940
2941        // Handle the special semantics of vertex_index/instance_index
2942        let ff_input = if self.options.special_constants_binding.is_some() {
2943            func_ctx.is_fixed_function_input(expr, module)
2944        } else {
2945            None
2946        };
2947        let closing_bracket = match ff_input {
2948            Some(crate::BuiltIn::VertexIndex) => {
2949                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2950                ")"
2951            }
2952            Some(crate::BuiltIn::InstanceIndex) => {
2953                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2954                ")"
2955            }
2956            Some(crate::BuiltIn::NumWorkGroups) => {
2957                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
2958                // in compute shaders the special constants contain the number
2959                // of workgroups, which we are using here.
2960                write!(
2961                    self.out,
2962                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2963                )?;
2964                return Ok(());
2965            }
2966            _ => "",
2967        };
2968
2969        if let Some(name) = self.named_expressions.get(&expr) {
2970            write!(self.out, "{name}{closing_bracket}")?;
2971            return Ok(());
2972        }
2973
2974        let expression = &func_ctx.expressions[expr];
2975
2976        match *expression {
2977            Expression::Literal(_)
2978            | Expression::Constant(_)
2979            | Expression::ZeroValue(_)
2980            | Expression::Compose { .. }
2981            | Expression::Splat { .. } => {
2982                self.write_possibly_const_expression(
2983                    module,
2984                    expr,
2985                    func_ctx.expressions,
2986                    |writer, expr| writer.write_expr(module, expr, func_ctx),
2987                )?;
2988            }
2989            Expression::Override(_) => return Err(Error::Override),
2990            // Avoid undefined behaviour for addition, subtraction, and
2991            // multiplication of signed integers by casting operands to
2992            // unsigned, performing the operation, then casting the result back
2993            // to signed.
2994            // TODO(#7109): This relies on the asint()/asuint() functions which only work
2995            // for 32-bit types, so we must find another solution for different bit widths.
2996            Expression::Binary {
2997                op:
2998                    op @ crate::BinaryOperator::Add
2999                    | op @ crate::BinaryOperator::Subtract
3000                    | op @ crate::BinaryOperator::Multiply,
3001                left,
3002                right,
3003            } if matches!(
3004                func_ctx.resolve_type(expr, &module.types).scalar(),
3005                Some(Scalar::I32)
3006            ) =>
3007            {
3008                write!(self.out, "asint(asuint(",)?;
3009                self.write_expr(module, left, func_ctx)?;
3010                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3011                self.write_expr(module, right, func_ctx)?;
3012                write!(self.out, "))")?;
3013            }
3014            // All of the multiplication can be expressed as `mul`,
3015            // except vector * vector, which needs to use the "*" operator.
3016            Expression::Binary {
3017                op: crate::BinaryOperator::Multiply,
3018                left,
3019                right,
3020            } if func_ctx.resolve_type(left, &module.types).is_matrix()
3021                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3022            {
3023                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
3024                write!(self.out, "mul(")?;
3025                self.write_expr(module, right, func_ctx)?;
3026                write!(self.out, ", ")?;
3027                self.write_expr(module, left, func_ctx)?;
3028                write!(self.out, ")")?;
3029            }
3030
3031            // WGSL says that floating-point division by zero should return
3032            // infinity. Microsoft's Direct3D 11 functional specification
3033            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
3034            // says:
3035            //
3036            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3037            //
3038            // which is what we want. The DXIL specification for the FDiv
3039            // instruction corroborates this:
3040            //
3041            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3042            Expression::Binary {
3043                op: crate::BinaryOperator::Divide,
3044                left,
3045                right,
3046            } if matches!(
3047                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3048                Some(ScalarKind::Sint | ScalarKind::Uint)
3049            ) =>
3050            {
3051                write!(self.out, "{DIV_FUNCTION}(")?;
3052                self.write_expr(module, left, func_ctx)?;
3053                write!(self.out, ", ")?;
3054                self.write_expr(module, right, func_ctx)?;
3055                write!(self.out, ")")?;
3056            }
3057
3058            Expression::Binary {
3059                op: crate::BinaryOperator::Modulo,
3060                left,
3061                right,
3062            } if matches!(
3063                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3064                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3065            ) =>
3066            {
3067                write!(self.out, "{MOD_FUNCTION}(")?;
3068                self.write_expr(module, left, func_ctx)?;
3069                write!(self.out, ", ")?;
3070                self.write_expr(module, right, func_ctx)?;
3071                write!(self.out, ")")?;
3072            }
3073
3074            Expression::Binary { op, left, right } => {
3075                write!(self.out, "(")?;
3076                self.write_expr(module, left, func_ctx)?;
3077                write!(self.out, " {} ", back::binary_operation_str(op))?;
3078                self.write_expr(module, right, func_ctx)?;
3079                write!(self.out, ")")?;
3080            }
3081            Expression::Access { base, index } => {
3082                if let Some(crate::AddressSpace::Storage { .. }) =
3083                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3084                {
3085                    // do nothing, the chain is written on `Load`/`Store`
3086                } else {
3087                    // We use the function __get_col_of_matCx2 here in cases
3088                    // where `base`s type resolves to a matCx2 and is part of a
3089                    // struct member with type of (possibly nested) array of matCx2's.
3090                    //
3091                    // Note that this only works for `Load`s and we handle
3092                    // `Store`s differently in `Statement::Store`.
3093                    if let Some(MatrixType {
3094                        columns,
3095                        rows: crate::VectorSize::Bi,
3096                        width: 4,
3097                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3098                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3099                    {
3100                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3101                        self.write_expr(module, base, func_ctx)?;
3102                        write!(self.out, ", ")?;
3103                        self.write_expr(module, index, func_ctx)?;
3104                        write!(self.out, ")")?;
3105                        return Ok(());
3106                    }
3107
3108                    let resolved = func_ctx.resolve_type(base, &module.types);
3109
3110                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3111                        TypeInner::BindingArray { .. } => {
3112                            let uniformity = &func_ctx.info[index].uniformity;
3113
3114                            (true, uniformity.non_uniform_result.is_some())
3115                        }
3116                        _ => (false, false),
3117                    };
3118
3119                    self.write_expr(module, base, func_ctx)?;
3120
3121                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3122                        module, func_ctx, base, resolved,
3123                    );
3124
3125                    if let Some(ref info) = array_sampler_info {
3126                        write!(self.out, "{}[", info.sampler_heap_name)?;
3127                    } else {
3128                        write!(self.out, "[")?;
3129                    }
3130
3131                    let needs_bound_check = self.options.restrict_indexing
3132                        && !indexing_binding_array
3133                        && match resolved.pointer_space() {
3134                            Some(
3135                                crate::AddressSpace::Function
3136                                | crate::AddressSpace::Private
3137                                | crate::AddressSpace::WorkGroup
3138                                | crate::AddressSpace::Immediate
3139                                | crate::AddressSpace::TaskPayload
3140                                | crate::AddressSpace::RayPayload
3141                                | crate::AddressSpace::IncomingRayPayload,
3142                            )
3143                            | None => true,
3144                            Some(crate::AddressSpace::Uniform) => {
3145                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3146                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3147                                let bind_target = self
3148                                    .options
3149                                    .resolve_resource_binding(
3150                                        module.global_variables[var_handle]
3151                                            .binding
3152                                            .as_ref()
3153                                            .unwrap(),
3154                                    )
3155                                    .unwrap();
3156                                bind_target.restrict_indexing
3157                            }
3158                            Some(
3159                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3160                            ) => unreachable!(),
3161                        };
3162                    // Decide whether this index needs to be clamped to fall within range.
3163                    let restriction_needed = if needs_bound_check {
3164                        index::access_needs_check(
3165                            base,
3166                            index::GuardedIndex::Expression(index),
3167                            module,
3168                            func_ctx.expressions,
3169                            func_ctx.info,
3170                        )
3171                    } else {
3172                        None
3173                    };
3174                    if let Some(limit) = restriction_needed {
3175                        write!(self.out, "min(uint(")?;
3176                        self.write_expr(module, index, func_ctx)?;
3177                        write!(self.out, "), ")?;
3178                        match limit {
3179                            index::IndexableLength::Known(limit) => {
3180                                write!(self.out, "{}u", limit - 1)?;
3181                            }
3182                            index::IndexableLength::Dynamic => unreachable!(),
3183                        }
3184                        write!(self.out, ")")?;
3185                    } else {
3186                        if non_uniform_qualifier {
3187                            write!(self.out, "NonUniformResourceIndex(")?;
3188                        }
3189                        if let Some(ref info) = array_sampler_info {
3190                            write!(
3191                                self.out,
3192                                "{}[{} + ",
3193                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3194                            )?;
3195                        }
3196                        self.write_expr(module, index, func_ctx)?;
3197                        if array_sampler_info.is_some() {
3198                            write!(self.out, "]")?;
3199                        }
3200                        if non_uniform_qualifier {
3201                            write!(self.out, ")")?;
3202                        }
3203                    }
3204
3205                    write!(self.out, "]")?;
3206                }
3207            }
3208            Expression::AccessIndex { base, index } => {
3209                if let Some(crate::AddressSpace::Storage { .. }) =
3210                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3211                {
3212                    // do nothing, the chain is written on `Load`/`Store`
3213                } else {
3214                    // See if we need to write the matrix column access in a
3215                    // special way since the type of `base` is our special
3216                    // __matCx2 struct.
3217                    if let Some(MatrixType {
3218                        rows: crate::VectorSize::Bi,
3219                        width: 4,
3220                        ..
3221                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3222                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3223                    {
3224                        self.write_expr(module, base, func_ctx)?;
3225                        write!(self.out, "._{index}")?;
3226                        return Ok(());
3227                    }
3228
3229                    let base_ty_res = &func_ctx.info[base].ty;
3230                    let mut resolved = base_ty_res.inner_with(&module.types);
3231                    let base_ty_handle = match *resolved {
3232                        TypeInner::Pointer { base, .. } => {
3233                            resolved = &module.types[base].inner;
3234                            Some(base)
3235                        }
3236                        _ => base_ty_res.handle(),
3237                    };
3238
3239                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3240                    // See the module-level block comment in mod.rs for details.
3241                    //
3242                    // We handle matrix reconstruction here for Loads.
3243                    // Stores are handled directly by `Statement::Store`.
3244                    if let TypeInner::Struct { ref members, .. } = *resolved {
3245                        let member = &members[index as usize];
3246
3247                        match module.types[member.ty].inner {
3248                            TypeInner::Matrix {
3249                                rows: crate::VectorSize::Bi,
3250                                ..
3251                            } if member.binding.is_none() => {
3252                                let ty = base_ty_handle.unwrap();
3253                                self.write_wrapped_struct_matrix_get_function_name(
3254                                    WrappedStructMatrixAccess { ty, index },
3255                                )?;
3256                                write!(self.out, "(")?;
3257                                self.write_expr(module, base, func_ctx)?;
3258                                write!(self.out, ")")?;
3259                                return Ok(());
3260                            }
3261                            _ => {}
3262                        }
3263                    }
3264
3265                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3266                        module, func_ctx, base, resolved,
3267                    );
3268
3269                    if let Some(ref info) = array_sampler_info {
3270                        write!(
3271                            self.out,
3272                            "{}[{}",
3273                            info.sampler_heap_name, info.sampler_index_buffer_name
3274                        )?;
3275                    }
3276
3277                    self.write_expr(module, base, func_ctx)?;
3278
3279                    match *resolved {
3280                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3281                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3282                        // it and generates completely nonsensical DXBC.
3283                        //
3284                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3285                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3286                            // Write vector access as a swizzle
3287                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3288                        }
3289                        TypeInner::Matrix { .. }
3290                        | TypeInner::Array { .. }
3291                        | TypeInner::BindingArray { .. } => {
3292                            if let Some(ref info) = array_sampler_info {
3293                                write!(
3294                                    self.out,
3295                                    "[{} + {index}]",
3296                                    info.binding_array_base_index_name
3297                                )?;
3298                            } else {
3299                                write!(self.out, "[{index}]")?;
3300                            }
3301                        }
3302                        TypeInner::Struct { .. } => {
3303                            // This will never panic in case the type is a `Struct`, this is not true
3304                            // for other types so we can only check while inside this match arm
3305                            let ty = base_ty_handle.unwrap();
3306
3307                            write!(
3308                                self.out,
3309                                ".{}",
3310                                &self.names[&NameKey::StructMember(ty, index)]
3311                            )?
3312                        }
3313                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3314                    }
3315
3316                    if array_sampler_info.is_some() {
3317                        write!(self.out, "]")?;
3318                    }
3319                }
3320            }
3321            Expression::FunctionArgument(pos) => {
3322                let ty = func_ctx.resolve_type(expr, &module.types);
3323
3324                // We know that any external texture function argument has been expanded into
3325                // separate consecutive arguments for each plane and the parameters buffer. And we
3326                // also know that external textures can only ever be used as an argument to another
3327                // function. Therefore we can simply emit each of the expanded arguments in a
3328                // consecutive comma-separated list.
3329                if let TypeInner::Image {
3330                    class: crate::ImageClass::External,
3331                    ..
3332                } = *ty
3333                {
3334                    let plane_names = [0, 1, 2].map(|i| {
3335                        &self.names[&func_ctx
3336                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3337                    });
3338                    let params_name = &self.names[&func_ctx
3339                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3340                    write!(
3341                        self.out,
3342                        "{}, {}, {}, {}",
3343                        plane_names[0], plane_names[1], plane_names[2], params_name
3344                    )?;
3345                } else {
3346                    let key = func_ctx.argument_key(pos);
3347                    let name = &self.names[&key];
3348                    write!(self.out, "{name}")?;
3349                }
3350            }
3351            Expression::ImageSample {
3352                coordinate,
3353                image,
3354                sampler,
3355                clamp_to_edge: true,
3356                gather: None,
3357                array_index: None,
3358                offset: None,
3359                level: crate::SampleLevel::Zero,
3360                depth_ref: None,
3361            } => {
3362                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3363                self.write_expr(module, image, func_ctx)?;
3364                write!(self.out, ", ")?;
3365                self.write_expr(module, sampler, func_ctx)?;
3366                write!(self.out, ", ")?;
3367                self.write_expr(module, coordinate, func_ctx)?;
3368                write!(self.out, ")")?;
3369            }
3370            Expression::ImageSample {
3371                image,
3372                sampler,
3373                gather,
3374                coordinate,
3375                array_index,
3376                offset,
3377                level,
3378                depth_ref,
3379                clamp_to_edge,
3380            } => {
3381                if clamp_to_edge {
3382                    return Err(Error::Custom(
3383                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3384                    ));
3385                }
3386
3387                use crate::SampleLevel as Sl;
3388                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3389
3390                let (base_str, component_str) = match gather {
3391                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3392                    None => ("Sample", ""),
3393                };
3394                let cmp_str = match depth_ref {
3395                    Some(_) => "Cmp",
3396                    None => "",
3397                };
3398                let level_str = match level {
3399                    Sl::Zero if gather.is_none() => "LevelZero",
3400                    Sl::Auto | Sl::Zero => "",
3401                    Sl::Exact(_) => "Level",
3402                    Sl::Bias(_) => "Bias",
3403                    Sl::Gradient { .. } => "Grad",
3404                };
3405
3406                self.write_expr(module, image, func_ctx)?;
3407                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3408                self.write_expr(module, sampler, func_ctx)?;
3409                write!(self.out, ", ")?;
3410                self.write_texture_coordinates(
3411                    "float",
3412                    coordinate,
3413                    array_index,
3414                    None,
3415                    module,
3416                    func_ctx,
3417                )?;
3418
3419                if let Some(depth_ref) = depth_ref {
3420                    write!(self.out, ", ")?;
3421                    self.write_expr(module, depth_ref, func_ctx)?;
3422                }
3423
3424                match level {
3425                    Sl::Auto | Sl::Zero => {}
3426                    Sl::Exact(expr) => {
3427                        write!(self.out, ", ")?;
3428                        self.write_expr(module, expr, func_ctx)?;
3429                    }
3430                    Sl::Bias(expr) => {
3431                        write!(self.out, ", ")?;
3432                        self.write_expr(module, expr, func_ctx)?;
3433                    }
3434                    Sl::Gradient { x, y } => {
3435                        write!(self.out, ", ")?;
3436                        self.write_expr(module, x, func_ctx)?;
3437                        write!(self.out, ", ")?;
3438                        self.write_expr(module, y, func_ctx)?;
3439                    }
3440                }
3441
3442                if let Some(offset) = offset {
3443                    write!(self.out, ", ")?;
3444                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3445                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3446                    write!(self.out, ")")?;
3447                }
3448
3449                write!(self.out, ")")?;
3450            }
3451            Expression::ImageQuery { image, query } => {
3452                // use wrapped image query function
3453                if let TypeInner::Image {
3454                    dim,
3455                    arrayed,
3456                    class,
3457                } = *func_ctx.resolve_type(image, &module.types)
3458                {
3459                    let wrapped_image_query = WrappedImageQuery {
3460                        dim,
3461                        arrayed,
3462                        class,
3463                        query: query.into(),
3464                    };
3465
3466                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3467                    write!(self.out, "(")?;
3468                    // Image always first param
3469                    self.write_expr(module, image, func_ctx)?;
3470                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3471                        write!(self.out, ", ")?;
3472                        self.write_expr(module, level, func_ctx)?;
3473                    }
3474                    write!(self.out, ")")?;
3475                }
3476            }
3477            Expression::ImageLoad {
3478                image,
3479                coordinate,
3480                array_index,
3481                sample,
3482                level,
3483            } => self.write_image_load(
3484                &module,
3485                expr,
3486                func_ctx,
3487                image,
3488                coordinate,
3489                array_index,
3490                sample,
3491                level,
3492            )?,
3493            Expression::GlobalVariable(handle) => {
3494                let global_variable = &module.global_variables[handle];
3495                let ty = &module.types[global_variable.ty].inner;
3496
3497                // In the case of binding arrays of samplers, we need to not write anything
3498                // as the we are in the wrong position to fully write the expression.
3499                //
3500                // The entire writing is done by AccessIndex.
3501                let is_binding_array_of_samplers = match *ty {
3502                    TypeInner::BindingArray { base, .. } => {
3503                        let base_ty = &module.types[base].inner;
3504                        matches!(*base_ty, TypeInner::Sampler { .. })
3505                    }
3506                    _ => false,
3507                };
3508
3509                let is_storage_space =
3510                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3511
3512                // Our external texture global variable has been expanded into multiple
3513                // global variables, one for each plane and the parameters buffer.
3514                // External textures can only ever be used as arguments to a function
3515                // call, and we know that an external texture argument to any function
3516                // will have been expanded to separate consecutive arguments for each
3517                // plane and the parameters buffer. Therefore we can simply emit each of
3518                // the expanded global variables in a consecutive comma-separated list.
3519                if let TypeInner::Image {
3520                    class: crate::ImageClass::External,
3521                    ..
3522                } = *ty
3523                {
3524                    let plane_names = [0, 1, 2].map(|i| {
3525                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3526                            handle,
3527                            ExternalTextureNameKey::Plane(i),
3528                        )]
3529                    });
3530                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3531                        handle,
3532                        ExternalTextureNameKey::Params,
3533                    )];
3534                    write!(
3535                        self.out,
3536                        "{}, {}, {}, {}",
3537                        plane_names[0], plane_names[1], plane_names[2], params_name
3538                    )?;
3539                } else if !is_binding_array_of_samplers && !is_storage_space {
3540                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3541                    write!(self.out, "{name}")?;
3542                }
3543            }
3544            Expression::LocalVariable(handle) => {
3545                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3546            }
3547            Expression::Load { pointer } => {
3548                match func_ctx
3549                    .resolve_type(pointer, &module.types)
3550                    .pointer_space()
3551                {
3552                    Some(crate::AddressSpace::Storage { .. }) => {
3553                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3554                        let result_ty = func_ctx.info[expr].ty.clone();
3555                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3556                    }
3557                    _ => {
3558                        let mut close_paren = false;
3559
3560                        // We cast the value loaded to a native HLSL floatCx2
3561                        // in cases where it is of type:
3562                        //  - __matCx2 or
3563                        //  - a (possibly nested) array of __matCx2's
3564                        if let Some(MatrixType {
3565                            rows: crate::VectorSize::Bi,
3566                            width: 4,
3567                            ..
3568                        }) = get_inner_matrix_of_struct_array_member(
3569                            module, pointer, func_ctx, false,
3570                        )
3571                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3572                        {
3573                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3574                            let ptr_tr = resolved.pointer_base_type();
3575                            if let Some(ptr_ty) =
3576                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3577                            {
3578                                resolved = ptr_ty;
3579                            }
3580
3581                            write!(self.out, "((")?;
3582                            if let TypeInner::Array { base, size, .. } = *resolved {
3583                                self.write_type(module, base)?;
3584                                self.write_array_size(module, base, size)?;
3585                            } else {
3586                                self.write_value_type(module, resolved)?;
3587                            }
3588                            write!(self.out, ")")?;
3589                            close_paren = true;
3590                        }
3591
3592                        self.write_expr(module, pointer, func_ctx)?;
3593
3594                        if close_paren {
3595                            write!(self.out, ")")?;
3596                        }
3597                    }
3598                }
3599            }
3600            Expression::Unary { op, expr } => {
3601                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3602                let op_str = match op {
3603                    crate::UnaryOperator::Negate => {
3604                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3605                            Some(Scalar::I32) => NEG_FUNCTION,
3606                            _ => "-",
3607                        }
3608                    }
3609                    crate::UnaryOperator::LogicalNot => "!",
3610                    crate::UnaryOperator::BitwiseNot => "~",
3611                };
3612                write!(self.out, "{op_str}(")?;
3613                self.write_expr(module, expr, func_ctx)?;
3614                write!(self.out, ")")?;
3615            }
3616            Expression::As {
3617                expr,
3618                kind,
3619                convert,
3620            } => {
3621                let inner = func_ctx.resolve_type(expr, &module.types);
3622                if inner.scalar_kind() == Some(ScalarKind::Float)
3623                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3624                    && convert.is_some()
3625                {
3626                    // Use helper functions for float to int casts in order to
3627                    // avoid undefined behaviour when value is out of range for
3628                    // the target type.
3629                    let fun_name = match (kind, convert) {
3630                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3631                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3632                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3633                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3634                        _ => unreachable!(),
3635                    };
3636                    write!(self.out, "{fun_name}(")?;
3637                    self.write_expr(module, expr, func_ctx)?;
3638                    write!(self.out, ")")?;
3639                } else {
3640                    let close_paren = match convert {
3641                        Some(dst_width) => {
3642                            let scalar = Scalar {
3643                                kind,
3644                                width: dst_width,
3645                            };
3646                            match *inner {
3647                                TypeInner::Vector { size, .. } => {
3648                                    write!(
3649                                        self.out,
3650                                        "{}{}(",
3651                                        scalar.to_hlsl_str()?,
3652                                        common::vector_size_str(size)
3653                                    )?;
3654                                }
3655                                TypeInner::Scalar(_) => {
3656                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3657                                }
3658                                TypeInner::Matrix { columns, rows, .. } => {
3659                                    write!(
3660                                        self.out,
3661                                        "{}{}x{}(",
3662                                        scalar.to_hlsl_str()?,
3663                                        common::vector_size_str(columns),
3664                                        common::vector_size_str(rows)
3665                                    )?;
3666                                }
3667                                _ => {
3668                                    return Err(Error::Unimplemented(format!(
3669                                        "write_expr expression::as {inner:?}"
3670                                    )));
3671                                }
3672                            };
3673                            true
3674                        }
3675                        None => {
3676                            if inner.scalar_width() == Some(8) {
3677                                false
3678                            } else {
3679                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3680                                true
3681                            }
3682                        }
3683                    };
3684                    self.write_expr(module, expr, func_ctx)?;
3685                    if close_paren {
3686                        write!(self.out, ")")?;
3687                    }
3688                }
3689            }
3690            Expression::Math {
3691                fun,
3692                arg,
3693                arg1,
3694                arg2,
3695                arg3,
3696            } => {
3697                use crate::MathFunction as Mf;
3698
3699                enum Function { // How a `Mf` math function gets written out as HLSL; selected by the `match fun` below.
3700                    Asincosh { is_sin: bool }, // asinh (`is_sin: true`) or acosh (`false`): `log(x + sqrt(x * x +/- 1.0))`
3701                    Atanh, // written as `0.5 * log((1.0 + x) / (1.0 - x))`
3702                    Pack2x16float, // `Pack*`: expanded inline; per-component values are shifted and OR-ed together
3703                    Pack2x16snorm,
3704                    Pack2x16unorm,
3705                    Pack4x8snorm,
3706                    Pack4x8unorm,
3707                    Pack4xI8,
3708                    Pack4xU8,
3709                    Pack4xI8Clamp, // like `Pack4xI8`, with the argument wrapped in `clamp(.., -128, 127)`
3710                    Pack4xU8Clamp, // like `Pack4xU8`, with the argument wrapped in `clamp(.., 0, 255)`
3711                    Unpack2x16float, // `Unpack*`: expanded inline; the argument is repeated and shifted/masked per component
3712                    Unpack2x16snorm,
3713                    Unpack2x16unorm,
3714                    Unpack4x8snorm,
3715                    Unpack4x8unorm,
3716                    Unpack4xI8,
3717                    Unpack4xU8,
3718                    Dot4I8Packed, // `dot4add_i8packed` when the shader model is >= 6.4, otherwise an unpack-and-`dot` polyfill
3719                    Dot4U8Packed, // `dot4add_u8packed` when the shader model is >= 6.4, otherwise an unpack-and-`dot` polyfill
3720                    QuantizeToF16, // written as `f16tof32(f32tof16(x))`
3721                    Regular(&'static str), // plain call of the named function with `arg` plus whichever of `arg1..arg3` are present
3722                    MissingIntOverload(&'static str), // named call; for `I32` arguments written as `asint(name(asuint(x)))`
3723                    MissingIntReturnType(&'static str), // named call; for `I32` arguments the result is wrapped in `asint(..)`
3724                    CountTrailingZeros, // has its own arm: type-dependent expansion built on `firstbitlow`
3725                    CountLeadingZeros,
3726                }
3727
3728                let fun = match fun {
3729                    // comparison
3730                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3731                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3732                        _ => Function::Regular("abs"),
3733                    },
3734                    Mf::Min => Function::Regular("min"),
3735                    Mf::Max => Function::Regular("max"),
3736                    Mf::Clamp => Function::Regular("clamp"),
3737                    Mf::Saturate => Function::Regular("saturate"),
3738                    // trigonometry
3739                    Mf::Cos => Function::Regular("cos"),
3740                    Mf::Cosh => Function::Regular("cosh"),
3741                    Mf::Sin => Function::Regular("sin"),
3742                    Mf::Sinh => Function::Regular("sinh"),
3743                    Mf::Tan => Function::Regular("tan"),
3744                    Mf::Tanh => Function::Regular("tanh"),
3745                    Mf::Acos => Function::Regular("acos"),
3746                    Mf::Asin => Function::Regular("asin"),
3747                    Mf::Atan => Function::Regular("atan"),
3748                    Mf::Atan2 => Function::Regular("atan2"),
3749                    Mf::Asinh => Function::Asincosh { is_sin: true },
3750                    Mf::Acosh => Function::Asincosh { is_sin: false },
3751                    Mf::Atanh => Function::Atanh,
3752                    Mf::Radians => Function::Regular("radians"),
3753                    Mf::Degrees => Function::Regular("degrees"),
3754                    // decomposition
3755                    Mf::Ceil => Function::Regular("ceil"),
3756                    Mf::Floor => Function::Regular("floor"),
3757                    Mf::Round => Function::Regular("round"),
3758                    Mf::Fract => Function::Regular("frac"),
3759                    Mf::Trunc => Function::Regular("trunc"),
3760                    Mf::Modf => Function::Regular(MODF_FUNCTION),
3761                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3762                    Mf::Ldexp => Function::Regular("ldexp"),
3763                    // exponent
3764                    Mf::Exp => Function::Regular("exp"),
3765                    Mf::Exp2 => Function::Regular("exp2"),
3766                    Mf::Log => Function::Regular("log"),
3767                    Mf::Log2 => Function::Regular("log2"),
3768                    Mf::Pow => Function::Regular("pow"),
3769                    // geometry
3770                    Mf::Dot => Function::Regular("dot"),
3771                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
3772                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
3773                    //Mf::Outer => ,
3774                    Mf::Cross => Function::Regular("cross"),
3775                    Mf::Distance => Function::Regular("distance"),
3776                    Mf::Length => Function::Regular("length"),
3777                    Mf::Normalize => Function::Regular("normalize"),
3778                    Mf::FaceForward => Function::Regular("faceforward"),
3779                    Mf::Reflect => Function::Regular("reflect"),
3780                    Mf::Refract => Function::Regular("refract"),
3781                    // computational
3782                    Mf::Sign => Function::Regular("sign"),
3783                    Mf::Fma => Function::Regular("mad"),
3784                    Mf::Mix => Function::Regular("lerp"),
3785                    Mf::Step => Function::Regular("step"),
3786                    Mf::SmoothStep => Function::Regular("smoothstep"),
3787                    Mf::Sqrt => Function::Regular("sqrt"),
3788                    Mf::InverseSqrt => Function::Regular("rsqrt"),
3789                    //Mf::Inverse =>,
3790                    Mf::Transpose => Function::Regular("transpose"),
3791                    Mf::Determinant => Function::Regular("determinant"),
3792                    Mf::QuantizeToF16 => Function::QuantizeToF16,
3793                    // bits
3794                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
3795                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
3796                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3797                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3798                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3799                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3800                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3801                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3802                    // Data Packing
3803                    Mf::Pack2x16float => Function::Pack2x16float,
3804                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
3805                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
3806                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
3807                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
3808                    Mf::Pack4xI8 => Function::Pack4xI8,
3809                    Mf::Pack4xU8 => Function::Pack4xU8,
3810                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3811                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3812                    // Data Unpacking
3813                    Mf::Unpack2x16float => Function::Unpack2x16float,
3814                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3815                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3816                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3817                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3818                    Mf::Unpack4xI8 => Function::Unpack4xI8,
3819                    Mf::Unpack4xU8 => Function::Unpack4xU8,
3820                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3821                };
3822
3823                match fun {
3824                    Function::Asincosh { is_sin } => {
3825                        write!(self.out, "log(")?;
3826                        self.write_expr(module, arg, func_ctx)?;
3827                        write!(self.out, " + sqrt(")?;
3828                        self.write_expr(module, arg, func_ctx)?;
3829                        write!(self.out, " * ")?;
3830                        self.write_expr(module, arg, func_ctx)?;
3831                        match is_sin {
3832                            true => write!(self.out, " + 1.0))")?,
3833                            false => write!(self.out, " - 1.0))")?,
3834                        }
3835                    }
3836                    Function::Atanh => {
3837                        write!(self.out, "0.5 * log((1.0 + ")?;
3838                        self.write_expr(module, arg, func_ctx)?;
3839                        write!(self.out, ") / (1.0 - ")?;
3840                        self.write_expr(module, arg, func_ctx)?;
3841                        write!(self.out, "))")?;
3842                    }
3843                    Function::Pack2x16float => {
3844                        write!(self.out, "(f32tof16(")?;
3845                        self.write_expr(module, arg, func_ctx)?;
3846                        write!(self.out, "[0]) | f32tof16(")?;
3847                        self.write_expr(module, arg, func_ctx)?;
3848                        write!(self.out, "[1]) << 16)")?;
3849                    }
3850                    Function::Pack2x16snorm => {
3851                        let scale = 32767;
3852
3853                        write!(self.out, "uint((int(round(clamp(")?;
3854                        self.write_expr(module, arg, func_ctx)?;
3855                        write!(
3856                            self.out,
3857                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3858                        )?;
3859                        self.write_expr(module, arg, func_ctx)?;
3860                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3861                    }
3862                    Function::Pack2x16unorm => {
3863                        let scale = 65535;
3864
3865                        write!(self.out, "(uint(round(clamp(")?;
3866                        self.write_expr(module, arg, func_ctx)?;
3867                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3868                        self.write_expr(module, arg, func_ctx)?;
3869                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3870                    }
3871                    Function::Pack4x8snorm => {
3872                        let scale = 127;
3873
3874                        write!(self.out, "uint((int(round(clamp(")?;
3875                        self.write_expr(module, arg, func_ctx)?;
3876                        write!(
3877                            self.out,
3878                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3879                        )?;
3880                        self.write_expr(module, arg, func_ctx)?;
3881                        write!(
3882                            self.out,
3883                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3884                        )?;
3885                        self.write_expr(module, arg, func_ctx)?;
3886                        write!(
3887                            self.out,
3888                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3889                        )?;
3890                        self.write_expr(module, arg, func_ctx)?;
3891                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3892                    }
3893                    Function::Pack4x8unorm => {
3894                        let scale = 255;
3895
3896                        write!(self.out, "(uint(round(clamp(")?;
3897                        self.write_expr(module, arg, func_ctx)?;
3898                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3899                        self.write_expr(module, arg, func_ctx)?;
3900                        write!(
3901                            self.out,
3902                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3903                        )?;
3904                        self.write_expr(module, arg, func_ctx)?;
3905                        write!(
3906                            self.out,
3907                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3908                        )?;
3909                        self.write_expr(module, arg, func_ctx)?;
3910                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3911                    }
3912                    fun @ (Function::Pack4xI8
3913                    | Function::Pack4xU8
3914                    | Function::Pack4xI8Clamp
3915                    | Function::Pack4xU8Clamp) => {
3916                        let was_signed =
3917                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3918                        let clamp_bounds = match fun {
3919                            Function::Pack4xI8Clamp => Some(("-128", "127")),
3920                            Function::Pack4xU8Clamp => Some(("0", "255")),
3921                            _ => None,
3922                        };
3923                        if was_signed {
3924                            write!(self.out, "uint(")?;
3925                        }
3926                        let write_arg = |this: &mut Self| -> BackendResult {
3927                            if let Some((min, max)) = clamp_bounds {
3928                                write!(this.out, "clamp(")?;
3929                                this.write_expr(module, arg, func_ctx)?;
3930                                write!(this.out, ", {min}, {max})")?;
3931                            } else {
3932                                this.write_expr(module, arg, func_ctx)?;
3933                            }
3934                            Ok(())
3935                        };
3936                        write!(self.out, "(")?;
3937                        write_arg(self)?;
3938                        write!(self.out, "[0] & 0xFF) | ((")?;
3939                        write_arg(self)?;
3940                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3941                        write_arg(self)?;
3942                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3943                        write_arg(self)?;
3944                        write!(self.out, "[3] & 0xFF) << 24)")?;
3945                        if was_signed {
3946                            write!(self.out, ")")?;
3947                        }
3948                    }
3949
3950                    Function::Unpack2x16float => {
3951                        write!(self.out, "float2(f16tof32(")?;
3952                        self.write_expr(module, arg, func_ctx)?;
3953                        write!(self.out, "), f16tof32((")?;
3954                        self.write_expr(module, arg, func_ctx)?;
3955                        write!(self.out, ") >> 16))")?;
3956                    }
3957                    Function::Unpack2x16snorm => {
3958                        let scale = 32767;
3959
3960                        write!(self.out, "(float2(int2(")?;
3961                        self.write_expr(module, arg, func_ctx)?;
3962                        write!(self.out, " << 16, ")?;
3963                        self.write_expr(module, arg, func_ctx)?;
3964                        write!(self.out, ") >> 16) / {scale}.0)")?;
3965                    }
3966                    Function::Unpack2x16unorm => {
3967                        let scale = 65535;
3968
3969                        write!(self.out, "(float2(")?;
3970                        self.write_expr(module, arg, func_ctx)?;
3971                        write!(self.out, " & 0xFFFF, ")?;
3972                        self.write_expr(module, arg, func_ctx)?;
3973                        write!(self.out, " >> 16) / {scale}.0)")?;
3974                    }
3975                    Function::Unpack4x8snorm => {
3976                        let scale = 127;
3977
3978                        write!(self.out, "(float4(int4(")?;
3979                        self.write_expr(module, arg, func_ctx)?;
3980                        write!(self.out, " << 24, ")?;
3981                        self.write_expr(module, arg, func_ctx)?;
3982                        write!(self.out, " << 16, ")?;
3983                        self.write_expr(module, arg, func_ctx)?;
3984                        write!(self.out, " << 8, ")?;
3985                        self.write_expr(module, arg, func_ctx)?;
3986                        write!(self.out, ") >> 24) / {scale}.0)")?;
3987                    }
3988                    Function::Unpack4x8unorm => {
3989                        let scale = 255;
3990
3991                        write!(self.out, "(float4(")?;
3992                        self.write_expr(module, arg, func_ctx)?;
3993                        write!(self.out, " & 0xFF, ")?;
3994                        self.write_expr(module, arg, func_ctx)?;
3995                        write!(self.out, " >> 8 & 0xFF, ")?;
3996                        self.write_expr(module, arg, func_ctx)?;
3997                        write!(self.out, " >> 16 & 0xFF, ")?;
3998                        self.write_expr(module, arg, func_ctx)?;
3999                        write!(self.out, " >> 24) / {scale}.0)")?;
4000                    }
4001                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4002                        write!(self.out, "(")?;
4003                        if matches!(fun, Function::Unpack4xU8) {
4004                            write!(self.out, "u")?;
4005                        }
4006                        write!(self.out, "int4(")?;
4007                        self.write_expr(module, arg, func_ctx)?;
4008                        write!(self.out, ", ")?;
4009                        self.write_expr(module, arg, func_ctx)?;
4010                        write!(self.out, " >> 8, ")?;
4011                        self.write_expr(module, arg, func_ctx)?;
4012                        write!(self.out, " >> 16, ")?;
4013                        self.write_expr(module, arg, func_ctx)?;
4014                        write!(self.out, " >> 24) << 24 >> 24)")?;
4015                    }
4016                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4017                        let arg1 = arg1.unwrap();
4018
4019                        if self.options.shader_model >= ShaderModel::V6_4 {
4020                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
4021                            let function_name = match fun {
4022                                Function::Dot4I8Packed => "dot4add_i8packed",
4023                                Function::Dot4U8Packed => "dot4add_u8packed",
4024                                _ => unreachable!(),
4025                            };
4026                            write!(self.out, "{function_name}(")?;
4027                            self.write_expr(module, arg, func_ctx)?;
4028                            write!(self.out, ", ")?;
4029                            self.write_expr(module, arg1, func_ctx)?;
4030                            write!(self.out, ", 0)")?;
4031                        } else {
4032                            // Fall back to a polyfill as `dot4add_{i, u}8packed` are not available.
4033                            write!(self.out, "dot(")?;
4034
4035                            if matches!(fun, Function::Dot4U8Packed) {
4036                                write!(self.out, "u")?;
4037                            }
4038                            write!(self.out, "int4(")?;
4039                            self.write_expr(module, arg, func_ctx)?;
4040                            write!(self.out, ", ")?;
4041                            self.write_expr(module, arg, func_ctx)?;
4042                            write!(self.out, " >> 8, ")?;
4043                            self.write_expr(module, arg, func_ctx)?;
4044                            write!(self.out, " >> 16, ")?;
4045                            self.write_expr(module, arg, func_ctx)?;
4046                            write!(self.out, " >> 24) << 24 >> 24, ")?;
4047
4048                            if matches!(fun, Function::Dot4U8Packed) {
4049                                write!(self.out, "u")?;
4050                            }
4051                            write!(self.out, "int4(")?;
4052                            self.write_expr(module, arg1, func_ctx)?;
4053                            write!(self.out, ", ")?;
4054                            self.write_expr(module, arg1, func_ctx)?;
4055                            write!(self.out, " >> 8, ")?;
4056                            self.write_expr(module, arg1, func_ctx)?;
4057                            write!(self.out, " >> 16, ")?;
4058                            self.write_expr(module, arg1, func_ctx)?;
4059                            write!(self.out, " >> 24) << 24 >> 24)")?;
4060                        }
4061                    }
4062                    Function::QuantizeToF16 => {
4063                        write!(self.out, "f16tof32(f32tof16(")?;
4064                        self.write_expr(module, arg, func_ctx)?;
4065                        write!(self.out, "))")?;
4066                    }
4067                    Function::Regular(fun_name) => {
4068                        write!(self.out, "{fun_name}(")?;
4069                        self.write_expr(module, arg, func_ctx)?;
4070                        if let Some(arg) = arg1 {
4071                            write!(self.out, ", ")?;
4072                            self.write_expr(module, arg, func_ctx)?;
4073                        }
4074                        if let Some(arg) = arg2 {
4075                            write!(self.out, ", ")?;
4076                            self.write_expr(module, arg, func_ctx)?;
4077                        }
4078                        if let Some(arg) = arg3 {
4079                            write!(self.out, ", ")?;
4080                            self.write_expr(module, arg, func_ctx)?;
4081                        }
4082                        write!(self.out, ")")?
4083                    }
4084                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4085                    // as non-32bit types are DXC only.
4086                    Function::MissingIntOverload(fun_name) => {
4087                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4088                        if let Some(Scalar::I32) = scalar_kind {
4089                            write!(self.out, "asint({fun_name}(asuint(")?;
4090                            self.write_expr(module, arg, func_ctx)?;
4091                            write!(self.out, ")))")?;
4092                        } else {
4093                            write!(self.out, "{fun_name}(")?;
4094                            self.write_expr(module, arg, func_ctx)?;
4095                            write!(self.out, ")")?;
4096                        }
4097                    }
4098                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4099                    // as non-32bit types are DXC only.
4100                    Function::MissingIntReturnType(fun_name) => {
4101                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4102                        if let Some(Scalar::I32) = scalar_kind {
4103                            write!(self.out, "asint({fun_name}(")?;
4104                            self.write_expr(module, arg, func_ctx)?;
4105                            write!(self.out, "))")?;
4106                        } else {
4107                            write!(self.out, "{fun_name}(")?;
4108                            self.write_expr(module, arg, func_ctx)?;
4109                            write!(self.out, ")")?;
4110                        }
4111                    }
4112                    Function::CountTrailingZeros => {
4113                        match *func_ctx.resolve_type(arg, &module.types) {
4114                            TypeInner::Vector { size, scalar } => {
4115                                let s = match size {
4116                                    crate::VectorSize::Bi => ".xx",
4117                                    crate::VectorSize::Tri => ".xxx",
4118                                    crate::VectorSize::Quad => ".xxxx",
4119                                };
4120
4121                                let scalar_width_bits = scalar.width * 8;
4122
4123                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4124                                    write!(
4125                                        self.out,
4126                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4127                                    )?;
4128                                    self.write_expr(module, arg, func_ctx)?;
4129                                    write!(self.out, "))")?;
4130                                } else {
4131                                    // This is only needed for the FXC path, on 32bit signed integers.
4132                                    write!(
4133                                        self.out,
4134                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4135                                    )?;
4136                                    self.write_expr(module, arg, func_ctx)?;
4137                                    write!(self.out, ")))")?;
4138                                }
4139                            }
4140                            TypeInner::Scalar(scalar) => {
4141                                let scalar_width_bits = scalar.width * 8;
4142
4143                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4144                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4145                                    self.write_expr(module, arg, func_ctx)?;
4146                                    write!(self.out, "))")?;
4147                                } else {
4148                                    // This is only needed for the FXC path, on 32bit signed integers.
4149                                    write!(
4150                                        self.out,
4151                                        "asint(min({scalar_width_bits}u, firstbitlow("
4152                                    )?;
4153                                    self.write_expr(module, arg, func_ctx)?;
4154                                    write!(self.out, ")))")?;
4155                                }
4156                            }
4157                            _ => unreachable!(),
4158                        }
4159
4160                        return Ok(());
4161                    }
4162                    Function::CountLeadingZeros => {
4163                        match *func_ctx.resolve_type(arg, &module.types) {
4164                            TypeInner::Vector { size, scalar } => {
4165                                let s = match size {
4166                                    crate::VectorSize::Bi => ".xx",
4167                                    crate::VectorSize::Tri => ".xxx",
4168                                    crate::VectorSize::Quad => ".xxxx",
4169                                };
4170
4171                                // scalar width - 1
4172                                let constant = scalar.width * 8 - 1;
4173
4174                                if scalar.kind == ScalarKind::Uint {
4175                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4176                                    self.write_expr(module, arg, func_ctx)?;
4177                                    write!(self.out, "))")?;
4178                                } else {
4179                                    let conversion_func = match scalar.width {
4180                                        4 => "asint",
4181                                        _ => "",
4182                                    };
4183                                    write!(self.out, "(")?;
4184                                    self.write_expr(module, arg, func_ctx)?;
4185                                    write!(
4186                                        self.out,
4187                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4188                                    )?;
4189                                    self.write_expr(module, arg, func_ctx)?;
4190                                    write!(self.out, ")))")?;
4191                                }
4192                            }
4193                            TypeInner::Scalar(scalar) => {
4194                                // scalar width - 1
4195                                let constant = scalar.width * 8 - 1;
4196
4197                                if let ScalarKind::Uint = scalar.kind {
4198                                    write!(self.out, "({constant}u - firstbithigh(")?;
4199                                    self.write_expr(module, arg, func_ctx)?;
4200                                    write!(self.out, "))")?;
4201                                } else {
4202                                    let conversion_func = match scalar.width {
4203                                        4 => "asint",
4204                                        _ => "",
4205                                    };
4206                                    write!(self.out, "(")?;
4207                                    self.write_expr(module, arg, func_ctx)?;
4208                                    write!(
4209                                        self.out,
4210                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4211                                    )?;
4212                                    self.write_expr(module, arg, func_ctx)?;
4213                                    write!(self.out, ")))")?;
4214                                }
4215                            }
4216                            _ => unreachable!(),
4217                        }
4218
4219                        return Ok(());
4220                    }
4221                }
4222            }
4223            Expression::Swizzle {
4224                size,
4225                vector,
4226                pattern,
4227            } => {
4228                self.write_expr(module, vector, func_ctx)?;
4229                write!(self.out, ".")?;
4230                for &sc in pattern[..size as usize].iter() {
4231                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4232                }
4233            }
4234            Expression::ArrayLength(expr) => {
4235                let var_handle = match func_ctx.expressions[expr] {
4236                    Expression::AccessIndex { base, index: _ } => {
4237                        match func_ctx.expressions[base] {
4238                            Expression::GlobalVariable(handle) => handle,
4239                            _ => unreachable!(),
4240                        }
4241                    }
4242                    Expression::GlobalVariable(handle) => handle,
4243                    _ => unreachable!(),
4244                };
4245
4246                let var = &module.global_variables[var_handle];
4247                let (offset, stride) = match module.types[var.ty].inner {
4248                    TypeInner::Array { stride, .. } => (0, stride),
4249                    TypeInner::Struct { ref members, .. } => {
4250                        let last = members.last().unwrap();
4251                        let stride = match module.types[last.ty].inner {
4252                            TypeInner::Array { stride, .. } => stride,
4253                            _ => unreachable!(),
4254                        };
4255                        (last.offset, stride)
4256                    }
4257                    _ => unreachable!(),
4258                };
4259
4260                let storage_access = match var.space {
4261                    crate::AddressSpace::Storage { access } => access,
4262                    _ => crate::StorageAccess::default(),
4263                };
4264                let wrapped_array_length = WrappedArrayLength {
4265                    writable: storage_access.contains(crate::StorageAccess::STORE),
4266                };
4267
4268                write!(self.out, "((")?;
4269                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4270                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4271                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4272            }
4273            Expression::Derivative { axis, ctrl, expr } => {
4274                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4275                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4276                    let tail = match ctrl {
4277                        Ctrl::Coarse => "coarse",
4278                        Ctrl::Fine => "fine",
4279                        Ctrl::None => unreachable!(),
4280                    };
4281                    write!(self.out, "abs(ddx_{tail}(")?;
4282                    self.write_expr(module, expr, func_ctx)?;
4283                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4284                    self.write_expr(module, expr, func_ctx)?;
4285                    write!(self.out, "))")?
4286                } else {
4287                    let fun_str = match (axis, ctrl) {
4288                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4289                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4290                        (Axis::X, Ctrl::None) => "ddx",
4291                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4292                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4293                        (Axis::Y, Ctrl::None) => "ddy",
4294                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4295                        (Axis::Width, Ctrl::None) => "fwidth",
4296                    };
4297                    write!(self.out, "{fun_str}(")?;
4298                    self.write_expr(module, expr, func_ctx)?;
4299                    write!(self.out, ")")?
4300                }
4301            }
4302            Expression::Relational { fun, argument } => {
4303                use crate::RelationalFunction as Rf;
4304
4305                let fun_str = match fun {
4306                    Rf::All => "all",
4307                    Rf::Any => "any",
4308                    Rf::IsNan => "isnan",
4309                    Rf::IsInf => "isinf",
4310                };
4311                write!(self.out, "{fun_str}(")?;
4312                self.write_expr(module, argument, func_ctx)?;
4313                write!(self.out, ")")?
4314            }
4315            Expression::Select {
4316                condition,
4317                accept,
4318                reject,
4319            } => {
4320                write!(self.out, "(")?;
4321                self.write_expr(module, condition, func_ctx)?;
4322                write!(self.out, " ? ")?;
4323                self.write_expr(module, accept, func_ctx)?;
4324                write!(self.out, " : ")?;
4325                self.write_expr(module, reject, func_ctx)?;
4326                write!(self.out, ")")?
4327            }
4328            Expression::RayQueryGetIntersection { query, committed } => {
4329                // For reasoning, see write_stmt
4330                let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4331                    unreachable!()
4332                };
4333
4334                let tracker_expr_name = format!(
4335                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4336                    self.names[&func_ctx.name_key(query_var)]
4337                );
4338
4339                if committed {
4340                    write!(self.out, "GetCommittedIntersection(")?;
4341                    self.write_expr(module, query, func_ctx)?;
4342                    write!(self.out, ", {tracker_expr_name})")?;
4343                } else {
4344                    write!(self.out, "GetCandidateIntersection(")?;
4345                    self.write_expr(module, query, func_ctx)?;
4346                    write!(self.out, ", {tracker_expr_name})")?;
4347                }
4348            }
4349            // Not supported yet
4350            Expression::RayQueryVertexPositions { .. }
4351            | Expression::CooperativeLoad { .. }
4352            | Expression::CooperativeMultiplyAdd { .. } => {
4353                unreachable!()
4354            }
4355            // Nothing to do here, since call expression already cached
4356            Expression::CallResult(_)
4357            | Expression::AtomicResult { .. }
4358            | Expression::WorkGroupUniformLoadResult { .. }
4359            | Expression::RayQueryProceedResult
4360            | Expression::SubgroupBallotResult
4361            | Expression::SubgroupOperationResult { .. } => {}
4362        }
4363
4364        if !closing_bracket.is_empty() {
4365            write!(self.out, "{closing_bracket}")?;
4366        }
4367        Ok(())
4368    }
4369
4370    #[allow(clippy::too_many_arguments)]
4371    fn write_image_load(
4372        &mut self,
4373        module: &&Module,
4374        expr: Handle<crate::Expression>,
4375        func_ctx: &back::FunctionCtx,
4376        image: Handle<crate::Expression>,
4377        coordinate: Handle<crate::Expression>,
4378        array_index: Option<Handle<crate::Expression>>,
4379        sample: Option<Handle<crate::Expression>>,
4380        level: Option<Handle<crate::Expression>>,
4381    ) -> Result<(), Error> {
4382        let mut wrapping_type = None;
4383        match *func_ctx.resolve_type(image, &module.types) {
4384            TypeInner::Image {
4385                class: crate::ImageClass::External,
4386                ..
4387            } => {
4388                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4389                self.write_expr(module, image, func_ctx)?;
4390                write!(self.out, ", ")?;
4391                self.write_expr(module, coordinate, func_ctx)?;
4392                write!(self.out, ")")?;
4393                return Ok(());
4394            }
4395            TypeInner::Image {
4396                class: crate::ImageClass::Storage { format, .. },
4397                ..
4398            } => {
4399                if format.single_component() {
4400                    wrapping_type = Some(Scalar::from(format));
4401                }
4402            }
4403            _ => {}
4404        }
4405        if let Some(scalar) = wrapping_type {
4406            write!(
4407                self.out,
4408                "{}{}(",
4409                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4410                scalar.to_hlsl_str()?
4411            )?;
4412        }
4413        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4414        self.write_expr(module, image, func_ctx)?;
4415        write!(self.out, ".Load(")?;
4416
4417        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4418
4419        if let Some(sample) = sample {
4420            write!(self.out, ", ")?;
4421            self.write_expr(module, sample, func_ctx)?;
4422        }
4423
4424        // close bracket for Load function
4425        write!(self.out, ")")?;
4426
4427        if wrapping_type.is_some() {
4428            write!(self.out, ")")?;
4429        }
4430
4431        // return x component if return type is scalar
4432        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4433            write!(self.out, ".x")?;
4434        }
4435        Ok(())
4436    }
4437
4438    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4439    /// can be generated later.
4440    fn sampler_binding_array_info_from_expression(
4441        &mut self,
4442        module: &Module,
4443        func_ctx: &back::FunctionCtx<'_>,
4444        base: Handle<crate::Expression>,
4445        resolved: &TypeInner,
4446    ) -> Option<BindingArraySamplerInfo> {
4447        if let TypeInner::BindingArray {
4448            base: base_ty_handle,
4449            ..
4450        } = *resolved
4451        {
4452            let base_ty = &module.types[base_ty_handle].inner;
4453            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4454                let base = &func_ctx.expressions[base];
4455
4456                if let crate::Expression::GlobalVariable(handle) = *base {
4457                    let variable = &module.global_variables[handle];
4458
4459                    let sampler_heap_name = match comparison {
4460                        true => COMPARISON_SAMPLER_HEAP_VAR,
4461                        false => SAMPLER_HEAP_VAR,
4462                    };
4463
4464                    return Some(BindingArraySamplerInfo {
4465                        sampler_heap_name,
4466                        sampler_index_buffer_name: self
4467                            .wrapped
4468                            .sampler_index_buffers
4469                            .get(&super::SamplerIndexBufferKey {
4470                                group: variable.binding.unwrap().group,
4471                            })
4472                            .unwrap()
4473                            .clone(),
4474                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4475                            .clone(),
4476                    });
4477                }
4478            }
4479        }
4480
4481        None
4482    }
4483
4484    fn write_named_expr(
4485        &mut self,
4486        module: &Module,
4487        handle: Handle<crate::Expression>,
4488        name: String,
4489        // The expression which is being named.
4490        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4491        named: Handle<crate::Expression>,
4492        ctx: &back::FunctionCtx,
4493    ) -> BackendResult {
4494        match ctx.info[named].ty {
4495            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4496                TypeInner::Struct { .. } => {
4497                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4498                    write!(self.out, "{ty_name}")?;
4499                }
4500                _ => {
4501                    self.write_type(module, ty_handle)?;
4502                }
4503            },
4504            proc::TypeResolution::Value(ref inner) => {
4505                self.write_value_type(module, inner)?;
4506            }
4507        }
4508
4509        let resolved = ctx.resolve_type(named, &module.types);
4510
4511        write!(self.out, " {name}")?;
4512        // If rhs is a array type, we should write array size
4513        if let TypeInner::Array { base, size, .. } = *resolved {
4514            self.write_array_size(module, base, size)?;
4515        }
4516        write!(self.out, " = ")?;
4517        self.write_expr(module, handle, ctx)?;
4518        writeln!(self.out, ";")?;
4519        self.named_expressions.insert(named, name);
4520
4521        Ok(())
4522    }
4523
4524    /// Helper function that write default zero initialization
4525    pub(super) fn write_default_init(
4526        &mut self,
4527        module: &Module,
4528        ty: Handle<crate::Type>,
4529    ) -> BackendResult {
4530        write!(self.out, "(")?;
4531        self.write_type(module, ty)?;
4532        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4533            self.write_array_size(module, base, size)?;
4534        }
4535        write!(self.out, ")0")?;
4536        Ok(())
4537    }
4538
4539    fn write_control_barrier(
4540        &mut self,
4541        barrier: crate::Barrier,
4542        level: back::Level,
4543    ) -> BackendResult {
4544        if barrier.contains(crate::Barrier::STORAGE) {
4545            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4546        }
4547        if barrier.contains(crate::Barrier::WORK_GROUP) {
4548            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4549        }
4550        if barrier.contains(crate::Barrier::SUB_GROUP) {
4551            // Does not exist in DirectX
4552        }
4553        if barrier.contains(crate::Barrier::TEXTURE) {
4554            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4555        }
4556        Ok(())
4557    }
4558
4559    fn write_memory_barrier(
4560        &mut self,
4561        barrier: crate::Barrier,
4562        level: back::Level,
4563    ) -> BackendResult {
4564        if barrier.contains(crate::Barrier::STORAGE) {
4565            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4566        }
4567        if barrier.contains(crate::Barrier::WORK_GROUP) {
4568            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4569        }
4570        if barrier.contains(crate::Barrier::SUB_GROUP) {
4571            // Does not exist in DirectX
4572        }
4573        if barrier.contains(crate::Barrier::TEXTURE) {
4574            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4575        }
4576        Ok(())
4577    }
4578
4579    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4580    fn emit_hlsl_atomic_tail(
4581        &mut self,
4582        module: &Module,
4583        func_ctx: &back::FunctionCtx<'_>,
4584        fun: &crate::AtomicFunction,
4585        compare_expr: Option<Handle<crate::Expression>>,
4586        value: Handle<crate::Expression>,
4587        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4588    ) -> BackendResult {
4589        if let Some(cmp) = compare_expr {
4590            write!(self.out, ", ")?;
4591            self.write_expr(module, cmp, func_ctx)?;
4592        }
4593        write!(self.out, ", ")?;
4594        if let crate::AtomicFunction::Subtract = *fun {
4595            // we just wrote `InterlockedAdd`, so negate the argument
4596            write!(self.out, "-")?;
4597        }
4598        self.write_expr(module, value, func_ctx)?;
4599        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4600            write!(self.out, ", ")?;
4601            if compare_expr.is_some() {
4602                write!(self.out, "{res_name}.old_value")?;
4603            } else {
4604                write!(self.out, "{res_name}")?;
4605            }
4606        }
4607        writeln!(self.out, ");")?;
4608        Ok(())
4609    }
4610}
4611
4612pub(super) struct MatrixType {
4613    pub(super) columns: crate::VectorSize,
4614    pub(super) rows: crate::VectorSize,
4615    pub(super) width: crate::Bytes,
4616}
4617
4618pub(super) fn get_inner_matrix_data(
4619    module: &Module,
4620    handle: Handle<crate::Type>,
4621) -> Option<MatrixType> {
4622    match module.types[handle].inner {
4623        TypeInner::Matrix {
4624            columns,
4625            rows,
4626            scalar,
4627        } => Some(MatrixType {
4628            columns,
4629            rows,
4630            width: scalar.width,
4631        }),
4632        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4633        _ => None,
4634    }
4635}
4636
4637/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4638/// returns a tuple of the matrix, the column (vector) index (if present), and
4639/// the row (scalar) index (if present).
4640fn find_matrix_in_access_chain(
4641    module: &Module,
4642    base: Handle<crate::Expression>,
4643    func_ctx: &back::FunctionCtx<'_>,
4644) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4645    let mut current_base = base;
4646    let mut vector = None;
4647    let mut scalar = None;
4648    loop {
4649        let resolved_tr = func_ctx
4650            .resolve_type(current_base, &module.types)
4651            .pointer_base_type();
4652        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4653
4654        match *resolved {
4655            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4656            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4657            _ => return None,
4658        }
4659
4660        let index;
4661        (current_base, index) = match func_ctx.expressions[current_base] {
4662            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4663            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4664            _ => return None,
4665        };
4666
4667        match *resolved {
4668            TypeInner::Scalar(_) => scalar = Some(index),
4669            TypeInner::Vector { .. } => vector = Some(index),
4670            _ => unreachable!(),
4671        }
4672    }
4673}
4674
4675/// Returns the matrix data if the access chain starting at `base`:
4676/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4677/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4678/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4679pub(super) fn get_inner_matrix_of_struct_array_member(
4680    module: &Module,
4681    base: Handle<crate::Expression>,
4682    func_ctx: &back::FunctionCtx<'_>,
4683    direct: bool,
4684) -> Option<MatrixType> {
4685    let mut mat_data = None;
4686    let mut array_base = None;
4687
4688    let mut current_base = base;
4689    loop {
4690        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4691        if let TypeInner::Pointer { base, .. } = *resolved {
4692            resolved = &module.types[base].inner;
4693        };
4694
4695        match *resolved {
4696            TypeInner::Matrix {
4697                columns,
4698                rows,
4699                scalar,
4700            } => {
4701                mat_data = Some(MatrixType {
4702                    columns,
4703                    rows,
4704                    width: scalar.width,
4705                })
4706            }
4707            TypeInner::Array { base, .. } => {
4708                array_base = Some(base);
4709            }
4710            TypeInner::Struct { .. } => {
4711                if let Some(array_base) = array_base {
4712                    if direct {
4713                        return mat_data;
4714                    } else {
4715                        return get_inner_matrix_data(module, array_base);
4716                    }
4717                }
4718
4719                break;
4720            }
4721            _ => break,
4722        }
4723
4724        current_base = match func_ctx.expressions[current_base] {
4725            crate::Expression::Access { base, .. } => base,
4726            crate::Expression::AccessIndex { base, .. } => base,
4727            _ => break,
4728        };
4729    }
4730    None
4731}
4732
4733/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
4734/// immediate expression, rather than traversing an access chain.
4735fn get_global_uniform_matrix(
4736    module: &Module,
4737    base: Handle<crate::Expression>,
4738    func_ctx: &back::FunctionCtx<'_>,
4739) -> Option<MatrixType> {
4740    let base_tr = func_ctx
4741        .resolve_type(base, &module.types)
4742        .pointer_base_type();
4743    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4744    match (&func_ctx.expressions[base], base_ty) {
4745        (
4746            &crate::Expression::GlobalVariable(handle),
4747            Some(&TypeInner::Matrix {
4748                columns,
4749                rows,
4750                scalar,
4751            }),
4752        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4753            Some(MatrixType {
4754                columns,
4755                rows,
4756                width: scalar.width,
4757            })
4758        }
4759        _ => None,
4760    }
4761}
4762
4763/// Returns the matrix data if the access chain starting at `base`:
4764/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
4765/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4766/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
4767fn get_inner_matrix_of_global_uniform(
4768    module: &Module,
4769    base: Handle<crate::Expression>,
4770    func_ctx: &back::FunctionCtx<'_>,
4771) -> Option<MatrixType> {
4772    let mut mat_data = None;
4773    let mut array_base = None;
4774
4775    let mut current_base = base;
4776    loop {
4777        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4778        if let TypeInner::Pointer { base, .. } = *resolved {
4779            resolved = &module.types[base].inner;
4780        };
4781
4782        match *resolved {
4783            TypeInner::Matrix {
4784                columns,
4785                rows,
4786                scalar,
4787            } => {
4788                mat_data = Some(MatrixType {
4789                    columns,
4790                    rows,
4791                    width: scalar.width,
4792                })
4793            }
4794            TypeInner::Array { base, .. } => {
4795                array_base = Some(base);
4796            }
4797            _ => break,
4798        }
4799
4800        current_base = match func_ctx.expressions[current_base] {
4801            crate::Expression::Access { base, .. } => base,
4802            crate::Expression::AccessIndex { base, .. } => base,
4803            crate::Expression::GlobalVariable(handle)
4804                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4805            {
4806                return mat_data.or_else(|| {
4807                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4808                })
4809            }
4810            _ => break,
4811        };
4812    }
4813    None
4814}