naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{fmt, mem};
7
8use super::{
9    help,
10    help::{
11        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
12        WrappedZeroValue,
13    },
14    storage::StoreValue,
15    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
16};
17use crate::{
18    back::{self, get_entry_points, Baked},
19    common,
20    proc::{self, index, ExternalTextureNameKey, NameKey},
21    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
22};
23
/// Semantic prefix for user-defined varyings that are not fragment outputs;
/// `write_semantic` emits it as `LOC{location}`.
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the HLSL struct type emitted for the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable holding the special constants.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Field of the special constants struct holding the first vertex (`int`).
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Field of the special constants struct holding the first instance (`int`).
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Extra `uint` field of the special constants struct; its meaning is not
/// visible in this chunk.
const SPECIAL_OTHER: &str = "other";

// Names of helper functions and global variables that the HLSL backend emits
// into the generated shader. They are shared with other backend modules (hence
// `pub(crate)`); the code that writes their definitions is outside this chunk.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
49
/// An index that is either a runtime expression or a constant value.
///
/// NOTE(review): not used within this chunk; presumably consumed by the
/// access-chain / storage code further down the file — verify.
enum Index {
    /// Index computed by a runtime expression.
    Expression(Handle<crate::Expression>),
    /// Index known as a constant.
    Static(u32),
}
54
/// One field of a flattened entry-point interface struct.
struct EpStructMember {
    /// Field name, allocated through the writer's namer.
    name: String,
    /// Type of the field.
    ty: Handle<crate::Type>,
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    binding: Option<crate::Binding>,
    /// Position of the member before sorting by binding: its order among the
    /// flattened arguments for inputs, or its index within the result struct
    /// for outputs.
    index: u32,
}
63
/// Structure contains information required for generating
/// wrapped structure of all entry points arguments
struct EntryPointBinding {
    /// Name of the fake EP argument that contains the struct
    /// with all the flattened input data.
    arg_name: String,
    /// Generated structure name
    ty_name: String,
    /// Members of generated structure; their order depends on whether this
    /// is an input or an output binding (see `EntryPointInterface`).
    members: Vec<EpStructMember>,
}
75
/// Special interface structs generated for one entry point.
pub(super) struct EntryPointInterface {
    /// If `Some`, the inputs of the entry point are gathered in a special
    /// struct whose fields are written sorted by binding.
    /// The `EntryPointBinding::members` array itself is kept sorted by index
    /// (the original argument order), so that we can walk it in
    /// `write_ep_arguments_initialization`.
    input: Option<EntryPointBinding>,
    /// If `Some`, the output of the entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in the `Statement::Return` handler.
    output: Option<EntryPointBinding>,
}
87
88#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
89enum InterfaceKey {
90    Location(u32),
91    BuiltIn(crate::BuiltIn),
92    Other,
93}
94
95impl InterfaceKey {
96    const fn new(binding: Option<&crate::Binding>) -> Self {
97        match binding {
98            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
99            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
100            None => Self::Other,
101        }
102    }
103}
104
/// Direction of an entry-point interface struct: stage inputs or stage outputs.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    Input,
    Output,
}
110
111const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
112    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
113        return false;
114    };
115    matches!(
116        builtin,
117        crate::BuiltIn::SubgroupSize
118            | crate::BuiltIn::SubgroupInvocationId
119            | crate::BuiltIn::NumSubgroups
120            | crate::BuiltIn::SubgroupId
121    )
122}
123
/// Information for how to generate a `binding_array<sampler>` access.
///
/// Not constructed within this chunk; the fields name HLSL variables that the
/// sampler-access code elsewhere in this file is expected to index.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
133
134impl<'a, W: fmt::Write> super::Writer<'a, W> {
135    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
136        Self {
137            out,
138            names: crate::FastHashMap::default(),
139            namer: proc::Namer::default(),
140            options,
141            pipeline_options,
142            entry_point_io: crate::FastHashMap::default(),
143            named_expressions: crate::NamedExpressions::default(),
144            wrapped: super::Wrapped::default(),
145            written_committed_intersection: false,
146            written_candidate_intersection: false,
147            continue_ctx: back::continue_forward::ContinueCtx::default(),
148            temp_access_chain: Vec::new(),
149            need_bake_expressions: Default::default(),
150        }
151    }
152
153    fn reset(&mut self, module: &Module) {
154        self.names.clear();
155        self.namer.reset(
156            module,
157            &super::keywords::RESERVED_SET,
158            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
159            super::keywords::RESERVED_PREFIXES,
160            &mut self.names,
161        );
162        self.entry_point_io.clear();
163        self.named_expressions.clear();
164        self.wrapped.clear();
165        self.written_committed_intersection = false;
166        self.written_candidate_intersection = false;
167        self.continue_ctx.clear();
168        self.need_bake_expressions.clear();
169    }
170
171    /// Generates statements to be inserted immediately before and at the very
172    /// start of the body of each loop, to defeat infinite loop reasoning.
173    /// The 0th item of the returned tuple should be inserted immediately prior
174    /// to the loop and the 1st item should be inserted at the very start of
175    /// the loop body.
176    ///
177    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
178    fn gen_force_bounded_loop_statements(
179        &mut self,
180        level: back::Level,
181    ) -> Option<(String, String)> {
182        if !self.options.force_loop_bounding {
183            return None;
184        }
185
186        let loop_bound_name = self.namer.call("loop_bound");
187        let max = u32::MAX;
188        // Count down from u32::MAX rather than up from 0 to avoid hang on
189        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
190        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
191        let level = level.next();
192        let break_and_inc = format!(
193            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
194{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
195        );
196
197        Some((decl, break_and_inc))
198    }
199
200    /// Helper method used to find which expressions of a given function require baking
201    ///
202    /// # Notes
203    /// Clears `need_bake_expressions` set before adding to it
204    fn update_expressions_to_bake(
205        &mut self,
206        module: &Module,
207        func: &crate::Function,
208        info: &valid::FunctionInfo,
209    ) {
210        use crate::Expression;
211        self.need_bake_expressions.clear();
212        for (exp_handle, expr) in func.expressions.iter() {
213            let expr_info = &info[exp_handle];
214            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
215            if min_ref_count <= expr_info.ref_count {
216                self.need_bake_expressions.insert(exp_handle);
217            }
218
219            if let Expression::Math { fun, arg, arg1, .. } = *expr {
220                match fun {
221                    crate::MathFunction::Asinh
222                    | crate::MathFunction::Acosh
223                    | crate::MathFunction::Atanh
224                    | crate::MathFunction::Unpack2x16float
225                    | crate::MathFunction::Unpack2x16snorm
226                    | crate::MathFunction::Unpack2x16unorm
227                    | crate::MathFunction::Unpack4x8snorm
228                    | crate::MathFunction::Unpack4x8unorm
229                    | crate::MathFunction::Unpack4xI8
230                    | crate::MathFunction::Unpack4xU8
231                    | crate::MathFunction::Pack2x16float
232                    | crate::MathFunction::Pack2x16snorm
233                    | crate::MathFunction::Pack2x16unorm
234                    | crate::MathFunction::Pack4x8snorm
235                    | crate::MathFunction::Pack4x8unorm
236                    | crate::MathFunction::Pack4xI8
237                    | crate::MathFunction::Pack4xU8
238                    | crate::MathFunction::Pack4xI8Clamp
239                    | crate::MathFunction::Pack4xU8Clamp => {
240                        self.need_bake_expressions.insert(arg);
241                    }
242                    crate::MathFunction::CountLeadingZeros => {
243                        let inner = info[exp_handle].ty.inner_with(&module.types);
244                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
245                            self.need_bake_expressions.insert(arg);
246                        }
247                    }
248                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
249                        self.need_bake_expressions.insert(arg);
250                        self.need_bake_expressions.insert(arg1.unwrap());
251                    }
252                    _ => {}
253                }
254            }
255
256            if let Expression::Derivative { axis, ctrl, expr } = *expr {
257                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
258                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
259                    self.need_bake_expressions.insert(expr);
260                }
261            }
262
263            if let Expression::GlobalVariable(_) = *expr {
264                let inner = info[exp_handle].ty.inner_with(&module.types);
265
266                if let TypeInner::Sampler { .. } = *inner {
267                    self.need_bake_expressions.insert(exp_handle);
268                }
269            }
270        }
271        for statement in func.body.iter() {
272            match *statement {
273                crate::Statement::SubgroupCollectiveOperation {
274                    op: _,
275                    collective_op: crate::CollectiveOperation::InclusiveScan,
276                    argument,
277                    result: _,
278                } => {
279                    self.need_bake_expressions.insert(argument);
280                }
281                crate::Statement::Atomic {
282                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
283                    ..
284                } => {
285                    self.need_bake_expressions.insert(cmp);
286                }
287                _ => {}
288            }
289        }
290    }
291
292    pub fn write(
293        &mut self,
294        module: &Module,
295        module_info: &valid::ModuleInfo,
296        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
297    ) -> Result<super::ReflectionInfo, Error> {
298        self.reset(module);
299
300        // Write special constants, if needed
301        if let Some(ref bt) = self.options.special_constants_binding {
302            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
303            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
304            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
305            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
306            writeln!(self.out, "}};")?;
307            write!(
308                self.out,
309                "ConstantBuffer<{}> {}: register(b{}",
310                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
311            )?;
312            if bt.space != 0 {
313                write!(self.out, ", space{}", bt.space)?;
314            }
315            writeln!(self.out, ");")?;
316
317            // Extra newline for readability
318            writeln!(self.out)?;
319        }
320
321        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
322            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
323            for i in 0..bt.size {
324                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
325            }
326            writeln!(self.out, "}};")?;
327            writeln!(
328                self.out,
329                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
330                group, group, bt.register, bt.space
331            )?;
332
333            // Extra newline for readability
334            writeln!(self.out)?;
335        }
336
337        // Save all entry point output types
338        let ep_results = module
339            .entry_points
340            .iter()
341            .map(|ep| (ep.stage, ep.function.result.clone()))
342            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
343
344        self.write_all_mat_cx2_typedefs_and_functions(module)?;
345
346        // Write all structs
347        for (handle, ty) in module.types.iter() {
348            if let TypeInner::Struct { ref members, span } = ty.inner {
349                if module.types[members.last().unwrap().ty]
350                    .inner
351                    .is_dynamically_sized(&module.types)
352                {
353                    // unsized arrays can only be in storage buffers,
354                    // for which we use `ByteAddressBuffer` anyway.
355                    continue;
356                }
357
358                let ep_result = ep_results.iter().find(|e| {
359                    if let Some(ref result) = e.1 {
360                        result.ty == handle
361                    } else {
362                        false
363                    }
364                });
365
366                self.write_struct(
367                    module,
368                    handle,
369                    members,
370                    span,
371                    ep_result.map(|r| (r.0, Io::Output)),
372                )?;
373                writeln!(self.out)?;
374            }
375        }
376
377        self.write_special_functions(module)?;
378
379        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
380        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
381
382        // Write all named constants
383        let mut constants = module
384            .constants
385            .iter()
386            .filter(|&(_, c)| c.name.is_some())
387            .peekable();
388        while let Some((handle, _)) = constants.next() {
389            self.write_global_constant(module, handle)?;
390            // Add extra newline for readability on last iteration
391            if constants.peek().is_none() {
392                writeln!(self.out)?;
393            }
394        }
395
396        // Write all globals
397        for (global, _) in module.global_variables.iter() {
398            self.write_global(module, global)?;
399        }
400
401        if !module.global_variables.is_empty() {
402            // Add extra newline for readability
403            writeln!(self.out)?;
404        }
405
406        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
407            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
408
409        // Write all entry points wrapped structs
410        for index in ep_range.clone() {
411            let ep = &module.entry_points[index];
412            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
413            let ep_io = self.write_ep_interface(
414                module,
415                &ep.function,
416                ep.stage,
417                &ep_name,
418                fragment_entry_point,
419            )?;
420            self.entry_point_io.insert(index, ep_io);
421        }
422
423        // Write all regular functions
424        for (handle, function) in module.functions.iter() {
425            let info = &module_info[handle];
426
427            // Check if all of the globals are accessible
428            if !self.options.fake_missing_bindings {
429                if let Some((var_handle, _)) =
430                    module
431                        .global_variables
432                        .iter()
433                        .find(|&(var_handle, var)| match var.binding {
434                            Some(ref binding) if !info[var_handle].is_empty() => {
435                                self.options.resolve_resource_binding(binding).is_err()
436                                    && self
437                                        .options
438                                        .resolve_external_texture_resource_binding(binding)
439                                        .is_err()
440                            }
441                            _ => false,
442                        })
443                {
444                    log::info!(
445                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
446                        handle,
447                        function.name,
448                        var_handle
449                    );
450                    continue;
451                }
452            }
453
454            let ctx = back::FunctionCtx {
455                ty: back::FunctionType::Function(handle),
456                info,
457                expressions: &function.expressions,
458                named_expressions: &function.named_expressions,
459            };
460            let name = self.names[&NameKey::Function(handle)].clone();
461
462            self.write_wrapped_functions(module, &ctx)?;
463
464            self.write_function(module, name.as_str(), function, &ctx, info)?;
465
466            writeln!(self.out)?;
467        }
468
469        let mut translated_ep_names = Vec::with_capacity(ep_range.len());
470
471        // Write all entry points
472        for index in ep_range {
473            let ep = &module.entry_points[index];
474            let info = module_info.get_entry_point(index);
475
476            if !self.options.fake_missing_bindings {
477                let mut ep_error = None;
478                for (var_handle, var) in module.global_variables.iter() {
479                    match var.binding {
480                        Some(ref binding) if !info[var_handle].is_empty() => {
481                            if let Err(err) = self.options.resolve_resource_binding(binding) {
482                                if self
483                                    .options
484                                    .resolve_external_texture_resource_binding(binding)
485                                    .is_err()
486                                {
487                                    ep_error = Some(err);
488                                    break;
489                                }
490                            }
491                        }
492                        _ => {}
493                    }
494                }
495                if let Some(err) = ep_error {
496                    translated_ep_names.push(Err(err));
497                    continue;
498                }
499            }
500
501            let ctx = back::FunctionCtx {
502                ty: back::FunctionType::EntryPoint(index as u16),
503                info,
504                expressions: &ep.function.expressions,
505                named_expressions: &ep.function.named_expressions,
506            };
507
508            self.write_wrapped_functions(module, &ctx)?;
509
510            if ep.stage == ShaderStage::Compute {
511                // HLSL is calling workgroup size "num threads"
512                let num_threads = ep.workgroup_size;
513                writeln!(
514                    self.out,
515                    "[numthreads({}, {}, {})]",
516                    num_threads[0], num_threads[1], num_threads[2]
517                )?;
518            }
519
520            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
521            self.write_function(module, &name, &ep.function, &ctx, info)?;
522
523            if index < module.entry_points.len() - 1 {
524                writeln!(self.out)?;
525            }
526
527            translated_ep_names.push(Ok(name));
528        }
529
530        Ok(super::ReflectionInfo {
531            entry_point_names: translated_ep_names,
532        })
533    }
534
535    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
536        match *binding {
537            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
538                write!(self.out, "precise ")?;
539            }
540            crate::Binding::Location {
541                interpolation,
542                sampling,
543                ..
544            } => {
545                if let Some(interpolation) = interpolation {
546                    if let Some(string) = interpolation.to_hlsl_str() {
547                        write!(self.out, "{string} ")?
548                    }
549                }
550
551                if let Some(sampling) = sampling {
552                    if let Some(string) = sampling.to_hlsl_str() {
553                        write!(self.out, "{string} ")?
554                    }
555                }
556            }
557            crate::Binding::BuiltIn(_) => {}
558        }
559
560        Ok(())
561    }
562
563    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
564    // if they are struct, so that the `stage` argument here could be omitted.
565    fn write_semantic(
566        &mut self,
567        binding: &Option<crate::Binding>,
568        stage: Option<(ShaderStage, Io)>,
569    ) -> BackendResult {
570        match *binding {
571            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
572                let builtin_str = builtin.to_hlsl_str()?;
573                write!(self.out, " : {builtin_str}")?;
574            }
575            Some(crate::Binding::Location {
576                blend_src: Some(1), ..
577            }) => {
578                write!(self.out, " : SV_Target1")?;
579            }
580            Some(crate::Binding::Location { location, .. }) => {
581                if stage == Some((ShaderStage::Fragment, Io::Output)) {
582                    write!(self.out, " : SV_Target{location}")?;
583                } else {
584                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
585                }
586            }
587            _ => {}
588        }
589
590        Ok(())
591    }
592
593    fn write_interface_struct(
594        &mut self,
595        module: &Module,
596        shader_stage: (ShaderStage, Io),
597        struct_name: String,
598        mut members: Vec<EpStructMember>,
599    ) -> Result<EntryPointBinding, Error> {
600        // Sort the members so that first come the user-defined varyings
601        // in ascending locations, and then built-ins. This allows VS and FS
602        // interfaces to match with regards to order.
603        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
604
605        write!(self.out, "struct {struct_name}")?;
606        writeln!(self.out, " {{")?;
607        for m in members.iter() {
608            // Sanity check that each IO member is a built-in or is assigned a
609            // location. Also see note about nesting in `write_ep_input_struct`.
610            debug_assert!(m.binding.is_some());
611
612            if is_subgroup_builtin_binding(&m.binding) {
613                continue;
614            }
615            write!(self.out, "{}", back::INDENT)?;
616            if let Some(ref binding) = m.binding {
617                self.write_modifier(binding)?;
618            }
619            self.write_type(module, m.ty)?;
620            write!(self.out, " {}", &m.name)?;
621            self.write_semantic(&m.binding, Some(shader_stage))?;
622            writeln!(self.out, ";")?;
623        }
624        if members.iter().any(|arg| {
625            matches!(
626                arg.binding,
627                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
628            )
629        }) {
630            writeln!(
631                self.out,
632                "{}uint __local_invocation_index : SV_GroupIndex;",
633                back::INDENT
634            )?;
635        }
636        writeln!(self.out, "}};")?;
637        writeln!(self.out)?;
638
639        // See ordering notes on EntryPointInterface fields
640        match shader_stage.1 {
641            Io::Input => {
642                // bring back the original order
643                members.sort_by_key(|m| m.index);
644            }
645            Io::Output => {
646                // keep it sorted by binding
647            }
648        }
649
650        Ok(EntryPointBinding {
651            arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
652            ty_name: struct_name,
653            members,
654        })
655    }
656
657    /// Flatten all entry point arguments into a single struct.
658    /// This is needed since we need to re-order them: first placing user locations,
659    /// then built-ins.
660    fn write_ep_input_struct(
661        &mut self,
662        module: &Module,
663        func: &crate::Function,
664        stage: ShaderStage,
665        entry_point_name: &str,
666    ) -> Result<EntryPointBinding, Error> {
667        let struct_name = format!("{stage:?}Input_{entry_point_name}");
668
669        let mut fake_members = Vec::new();
670        for arg in func.arguments.iter() {
671            // NOTE: We don't need to handle nesting structs. All members must
672            // be either built-ins or assigned a location. I.E. `binding` is
673            // `Some`. This is checked in `VaryingContext::validate`. See:
674            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
675            match module.types[arg.ty].inner {
676                TypeInner::Struct { ref members, .. } => {
677                    for member in members.iter() {
678                        let name = self.namer.call_or(&member.name, "member");
679                        let index = fake_members.len() as u32;
680                        fake_members.push(EpStructMember {
681                            name,
682                            ty: member.ty,
683                            binding: member.binding.clone(),
684                            index,
685                        });
686                    }
687                }
688                _ => {
689                    let member_name = self.namer.call_or(&arg.name, "member");
690                    let index = fake_members.len() as u32;
691                    fake_members.push(EpStructMember {
692                        name: member_name,
693                        ty: arg.ty,
694                        binding: arg.binding.clone(),
695                        index,
696                    });
697                }
698            }
699        }
700
701        self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
702    }
703
704    /// Flatten all entry point results into a single struct.
705    /// This is needed since we need to re-order them: first placing user locations,
706    /// then built-ins.
707    fn write_ep_output_struct(
708        &mut self,
709        module: &Module,
710        result: &crate::FunctionResult,
711        stage: ShaderStage,
712        entry_point_name: &str,
713        frag_ep: Option<&FragmentEntryPoint<'_>>,
714    ) -> Result<EntryPointBinding, Error> {
715        let struct_name = format!("{stage:?}Output_{entry_point_name}");
716
717        let empty = [];
718        let members = match module.types[result.ty].inner {
719            TypeInner::Struct { ref members, .. } => members,
720            ref other => {
721                log::error!("Unexpected {other:?} output type without a binding");
722                &empty[..]
723            }
724        };
725
726        // Gather list of fragment input locations. We use this below to remove user-defined
727        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
728        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
729        // writer is supplied with information about the fragment entry point.
730        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
731            let mut fs_input_locs = Vec::new();
732            for arg in frag_ep.func.arguments.iter() {
733                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
734                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
735                    Some(crate::Binding::BuiltIn(_)) | None => {}
736                };
737
738                // NOTE: We don't need to handle struct nesting. See note in
739                // `write_ep_input_struct`.
740                match frag_ep.module.types[arg.ty].inner {
741                    TypeInner::Struct { ref members, .. } => {
742                        for member in members.iter() {
743                            push_if_location(&member.binding);
744                        }
745                    }
746                    _ => push_if_location(&arg.binding),
747                }
748            }
749            fs_input_locs.sort();
750            Some(fs_input_locs)
751        } else {
752            None
753        };
754
755        let mut fake_members = Vec::new();
756        for (index, member) in members.iter().enumerate() {
757            if let Some(ref fs_input_locs) = fs_input_locs {
758                match member.binding {
759                    Some(crate::Binding::Location { location, .. }) => {
760                        if fs_input_locs.binary_search(&location).is_err() {
761                            continue;
762                        }
763                    }
764                    Some(crate::Binding::BuiltIn(_)) | None => {}
765                }
766            }
767
768            let member_name = self.namer.call_or(&member.name, "member");
769            fake_members.push(EpStructMember {
770                name: member_name,
771                ty: member.ty,
772                binding: member.binding.clone(),
773                index: index as u32,
774            });
775        }
776
777        self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
778    }
779
780    /// Writes special interface structures for an entry point. The special structures have
781    /// all the fields flattened into them and sorted by binding. They are needed to emulate
782    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
783    fn write_ep_interface(
784        &mut self,
785        module: &Module,
786        func: &crate::Function,
787        stage: ShaderStage,
788        ep_name: &str,
789        frag_ep: Option<&FragmentEntryPoint<'_>>,
790    ) -> Result<EntryPointInterface, Error> {
791        Ok(EntryPointInterface {
792            input: if !func.arguments.is_empty()
793                && (stage == ShaderStage::Fragment
794                    || func
795                        .arguments
796                        .iter()
797                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
798            {
799                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
800            } else {
801                None
802            },
803            output: match func.result {
804                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
805                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
806                }
807                _ => None,
808            },
809        })
810    }
811
812    fn write_ep_argument_initialization(
813        &mut self,
814        ep: &crate::EntryPoint,
815        ep_input: &EntryPointBinding,
816        fake_member: &EpStructMember,
817    ) -> BackendResult {
818        match fake_member.binding {
819            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
820                write!(self.out, "WaveGetLaneCount()")?
821            }
822            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
823                write!(self.out, "WaveGetLaneIndex()")?
824            }
825            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
826                self.out,
827                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
828                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
829            )?,
830            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
831                write!(
832                    self.out,
833                    "{}.__local_invocation_index / WaveGetLaneCount()",
834                    ep_input.arg_name
835                )?;
836            }
837            _ => {
838                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
839            }
840        }
841        Ok(())
842    }
843
844    /// Write an entry point preface that initializes the arguments as specified in IR.
845    fn write_ep_arguments_initialization(
846        &mut self,
847        module: &Module,
848        func: &crate::Function,
849        ep_index: u16,
850    ) -> BackendResult {
851        let ep = &module.entry_points[ep_index as usize];
852        let ep_input = match self
853            .entry_point_io
854            .get_mut(&(ep_index as usize))
855            .unwrap()
856            .input
857            .take()
858        {
859            Some(ep_input) => ep_input,
860            None => return Ok(()),
861        };
862        let mut fake_iter = ep_input.members.iter();
863        for (arg_index, arg) in func.arguments.iter().enumerate() {
864            write!(self.out, "{}", back::INDENT)?;
865            self.write_type(module, arg.ty)?;
866            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
867            write!(self.out, " {arg_name}")?;
868            match module.types[arg.ty].inner {
869                TypeInner::Array { base, size, .. } => {
870                    self.write_array_size(module, base, size)?;
871                    write!(self.out, " = ")?;
872                    self.write_ep_argument_initialization(
873                        ep,
874                        &ep_input,
875                        fake_iter.next().unwrap(),
876                    )?;
877                    writeln!(self.out, ";")?;
878                }
879                TypeInner::Struct { ref members, .. } => {
880                    write!(self.out, " = {{ ")?;
881                    for index in 0..members.len() {
882                        if index != 0 {
883                            write!(self.out, ", ")?;
884                        }
885                        self.write_ep_argument_initialization(
886                            ep,
887                            &ep_input,
888                            fake_iter.next().unwrap(),
889                        )?;
890                    }
891                    writeln!(self.out, " }};")?;
892                }
893                _ => {
894                    write!(self.out, " = ")?;
895                    self.write_ep_argument_initialization(
896                        ep,
897                        &ep_input,
898                        fake_iter.next().unwrap(),
899                    )?;
900                    writeln!(self.out, ";")?;
901                }
902            }
903        }
904        assert!(fake_iter.next().is_none());
905        Ok(())
906    }
907
    /// Helper method used to write global variables.
    ///
    /// Emits one HLSL declaration for `handle`, depending on its address space:
    /// - Private: a `static` variable.
    /// - WorkGroup: a `groupshared` variable.
    /// - Uniform: a `cbuffer`.
    /// - Storage: a `[RW]ByteAddressBuffer`.
    /// - Handle: a resource.
    /// - PushConstant: a `ConstantBuffer<T>`.
    ///
    /// A `register(...)` annotation follows whenever the global has a binding.
    /// External textures and samplers are delegated to
    /// `write_global_external_texture` and `write_global_sampler` respectively.
    ///
    /// # Notes
    /// Every declaration that is written ends in a newline. If `self.options` cannot
    /// resolve the global's resource binding, nothing is written. The skip is logged at
    /// info level and `Ok(())` is returned.
    ///
    /// # Errors
    /// Returns `Error::Unimplemented` for a push-constant global whose type is not a
    /// struct. Part of that declaration has already been written to `self.out` by then.
    ///
    /// # Panics
    /// - Panics on a global in the `Function` address space.
    /// - Panics on a push-constant global when `options.push_constants_target` is
    ///   `None`.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify the global by its element type so that arrays of
        // samplers/textures take the same paths as single ones.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are handled entirely differently, so defer entirely to that method.
        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
        // their bindings separately.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        // A binding the options cannot resolve marks the resource as inaccessible for
        // this output, so the declaration is skipped entirely.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::info!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // Samplers are handled entirely differently, so defer entirely to that method.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Write the declaration prefix and pick the register class letter (`b`, `t`, `u`)
        // used later in `register(...)`. Private/WorkGroup globals get no register.
        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // constant buffer declarations are expected to be inlined, e.g.
                // `cbuffer foo: register(b0) { field1: type1; }`
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                // Writable buffers go in `u` (UAV) registers, read-only ones in `t` registers.
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // all storage textures are UAV, unconditionally
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::PushConstant => {
                // The type of the push constants will be wrapped in `ConstantBuffer`
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
        };

        // If the global is a push constant write the type now because it will be a
        // generic argument to `ConstantBuffer`
        if global.space == crate::AddressSpace::PushConstant {
            self.write_global_type(module, global.ty)?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Push constants need to be assigned a binding explicitly by the consumer
        // since naga has no way to know the binding from the shader alone
        if global.space == crate::AddressSpace::PushConstant {
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .push_constants_target
                .as_ref()
                .expect("No bind target was defined for the push constants block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // this was already resolved earlier when we started evaluating an entry point.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // need to write the binding array size if the type was emitted with `write_type`
            // (an override from the bind target takes precedence over the IR size)
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Private globals always get an initializer: the IR's `init` expression when
            // present, otherwise whatever `write_default_init` emits for the type.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        // For uniforms the text so far is only the `cbuffer` header; the actual
        // variable is declared inside its body.
        if global.space == crate::AddressSpace::Uniform {
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1101
1102    fn write_global_sampler(
1103        &mut self,
1104        module: &Module,
1105        handle: Handle<crate::GlobalVariable>,
1106        global: &crate::GlobalVariable,
1107    ) -> BackendResult {
1108        let binding = *global.binding.as_ref().unwrap();
1109
1110        let key = super::SamplerIndexBufferKey {
1111            group: binding.group,
1112        };
1113        self.write_wrapped_sampler_buffer(key)?;
1114
1115        // This was already validated, so we can confidently unwrap it.
1116        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1117
1118        match module.types[global.ty].inner {
1119            TypeInner::Sampler { comparison } => {
1120                // If we are generating a static access, we create a variable for the sampler.
1121                //
1122                // This prevents the DXIL from containing multiple lookups for the sampler, which
1123                // the backend compiler will then have to eliminate. AMD does seem to be able to
1124                // eliminate these, but better safe than sorry.
1125
1126                write!(self.out, "static const ")?;
1127                self.write_type(module, global.ty)?;
1128
1129                let heap_var = if comparison {
1130                    COMPARISON_SAMPLER_HEAP_VAR
1131                } else {
1132                    SAMPLER_HEAP_VAR
1133                };
1134
1135                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1136                let name = &self.names[&NameKey::GlobalVariable(handle)];
1137                writeln!(
1138                    self.out,
1139                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1140                    register = bt.register
1141                )?;
1142            }
1143            TypeInner::BindingArray { .. } => {
1144                // If we are generating a binding array, we cannot directly access the sampler as the index
1145                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1146                // that represents the "base" index into the sampler index buffer. This constant is added
1147                // to the user provided index to get the final index into the sampler index buffer.
1148
1149                let name = &self.names[&NameKey::GlobalVariable(handle)];
1150                writeln!(
1151                    self.out,
1152                    "static const uint {name} = {register};",
1153                    register = bt.register
1154                )?;
1155            }
1156            _ => unreachable!(),
1157        };
1158
1159        Ok(())
1160    }
1161
1162    /// Write the declarations for an external texture global variable.
1163    /// These are emitted as multiple global variables: Three `Texture2D`s
1164    /// (one for each plane) and a parameters cbuffer.
1165    fn write_global_external_texture(
1166        &mut self,
1167        module: &Module,
1168        handle: Handle<crate::GlobalVariable>,
1169        global: &crate::GlobalVariable,
1170    ) -> BackendResult {
1171        let res_binding = global
1172            .binding
1173            .as_ref()
1174            .expect("External texture global variables must have a resource binding");
1175        let ext_tex_bindings = match self
1176            .options
1177            .resolve_external_texture_resource_binding(res_binding)
1178        {
1179            Ok(bindings) => bindings,
1180            Err(err) => {
1181                log::info!(
1182                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1183                    handle,
1184                    global.name,
1185                    err,
1186                );
1187                return Ok(());
1188            }
1189        };
1190
1191        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1192            write!(
1193                self.out,
1194                "Texture2D<float4> {}: register(t{}",
1195                name, bt.register
1196            )?;
1197            if bt.space != 0 {
1198                write!(self.out, ", space{}", bt.space)?;
1199            }
1200            writeln!(self.out, ");")?;
1201            Ok(())
1202        };
1203        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1204            let plane_name = &self.names
1205                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1206            write_plane(bt, plane_name)?;
1207        }
1208
1209        let params_name = &self.names
1210            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1211        let params_ty_name =
1212            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1213        write!(
1214            self.out,
1215            "cbuffer {}: register(b{}",
1216            params_name, ext_tex_bindings.params.register
1217        )?;
1218        if ext_tex_bindings.params.space != 0 {
1219            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1220        }
1221        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1222
1223        Ok(())
1224    }
1225
1226    /// Helper method used to write global constants
1227    ///
1228    /// # Notes
1229    /// Ends in a newline
1230    fn write_global_constant(
1231        &mut self,
1232        module: &Module,
1233        handle: Handle<crate::Constant>,
1234    ) -> BackendResult {
1235        write!(self.out, "static const ")?;
1236        let constant = &module.constants[handle];
1237        self.write_type(module, constant.ty)?;
1238        let name = &self.names[&NameKey::Constant(handle)];
1239        write!(self.out, " {name}")?;
1240        // Write size for array type
1241        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1242            self.write_array_size(module, base, size)?;
1243        }
1244        write!(self.out, " = ")?;
1245        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1246        writeln!(self.out, ";")?;
1247        Ok(())
1248    }
1249
1250    pub(super) fn write_array_size(
1251        &mut self,
1252        module: &Module,
1253        base: Handle<crate::Type>,
1254        size: crate::ArraySize,
1255    ) -> BackendResult {
1256        write!(self.out, "[")?;
1257
1258        match size.resolve(module.to_ctx())? {
1259            proc::IndexableLength::Known(size) => {
1260                write!(self.out, "{size}")?;
1261            }
1262            proc::IndexableLength::Dynamic => unreachable!(),
1263        }
1264
1265        write!(self.out, "]")?;
1266
1267        if let TypeInner::Array {
1268            base: next_base,
1269            size: next_size,
1270            ..
1271        } = module.types[base].inner
1272        {
1273            self.write_array_size(module, next_base, next_size)?;
1274        }
1275
1276        Ok(())
1277    }
1278
1279    /// Helper method used to write structs
1280    ///
1281    /// # Notes
1282    /// Ends in a newline
1283    fn write_struct(
1284        &mut self,
1285        module: &Module,
1286        handle: Handle<crate::Type>,
1287        members: &[crate::StructMember],
1288        span: u32,
1289        shader_stage: Option<(ShaderStage, Io)>,
1290    ) -> BackendResult {
1291        // Write struct name
1292        let struct_name = &self.names[&NameKey::Type(handle)];
1293        writeln!(self.out, "struct {struct_name} {{")?;
1294
1295        let mut last_offset = 0;
1296        for (index, member) in members.iter().enumerate() {
1297            if member.binding.is_none() && member.offset > last_offset {
1298                // using int as padding should work as long as the backend
1299                // doesn't support a type that's less than 4 bytes in size
1300                // (Error::UnsupportedScalar catches this)
1301                let padding = (member.offset - last_offset) / 4;
1302                for i in 0..padding {
1303                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1304                }
1305            }
1306            let ty_inner = &module.types[member.ty].inner;
1307            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1308
1309            // The indentation is only for readability
1310            write!(self.out, "{}", back::INDENT)?;
1311
1312            match module.types[member.ty].inner {
1313                TypeInner::Array { base, size, .. } => {
1314                    // HLSL arrays are written as `type name[size]`
1315
1316                    self.write_global_type(module, member.ty)?;
1317
1318                    // Write `name`
1319                    write!(
1320                        self.out,
1321                        " {}",
1322                        &self.names[&NameKey::StructMember(handle, index as u32)]
1323                    )?;
1324                    // Write [size]
1325                    self.write_array_size(module, base, size)?;
1326                }
1327                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1328                // See the module-level block comment in mod.rs for details.
1329                TypeInner::Matrix {
1330                    rows,
1331                    columns,
1332                    scalar,
1333                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1334                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1335                    let field_name_key = NameKey::StructMember(handle, index as u32);
1336
1337                    for i in 0..columns as u8 {
1338                        if i != 0 {
1339                            write!(self.out, "; ")?;
1340                        }
1341                        self.write_value_type(module, &vec_ty)?;
1342                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1343                    }
1344                }
1345                _ => {
1346                    // Write modifier before type
1347                    if let Some(ref binding) = member.binding {
1348                        self.write_modifier(binding)?;
1349                    }
1350
1351                    // Even though Naga IR matrices are column-major, we must describe
1352                    // matrices passed from the CPU as being in row-major order.
1353                    // See the module-level block comment in mod.rs for details.
1354                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1355                        write!(self.out, "row_major ")?;
1356                    }
1357
1358                    // Write the member type and name
1359                    self.write_type(module, member.ty)?;
1360                    write!(
1361                        self.out,
1362                        " {}",
1363                        &self.names[&NameKey::StructMember(handle, index as u32)]
1364                    )?;
1365                }
1366            }
1367
1368            self.write_semantic(&member.binding, shader_stage)?;
1369            writeln!(self.out, ";")?;
1370        }
1371
1372        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1373        if members.last().unwrap().binding.is_none() && span > last_offset {
1374            let padding = (span - last_offset) / 4;
1375            for i in 0..padding {
1376                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1377            }
1378        }
1379
1380        writeln!(self.out, "}};")?;
1381        Ok(())
1382    }
1383
1384    /// Helper method used to write global/structs non image/sampler types
1385    ///
1386    /// # Notes
1387    /// Adds no trailing or leading whitespace
1388    pub(super) fn write_global_type(
1389        &mut self,
1390        module: &Module,
1391        ty: Handle<crate::Type>,
1392    ) -> BackendResult {
1393        let matrix_data = get_inner_matrix_data(module, ty);
1394
1395        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1396        // See the module-level block comment in mod.rs for details.
1397        if let Some(MatrixType {
1398            columns,
1399            rows: crate::VectorSize::Bi,
1400            width: 4,
1401        }) = matrix_data
1402        {
1403            write!(self.out, "__mat{}x2", columns as u8)?;
1404        } else {
1405            // Even though Naga IR matrices are column-major, we must describe
1406            // matrices passed from the CPU as being in row-major order.
1407            // See the module-level block comment in mod.rs for details.
1408            if matrix_data.is_some() {
1409                write!(self.out, "row_major ")?;
1410            }
1411
1412            self.write_type(module, ty)?;
1413        }
1414
1415        Ok(())
1416    }
1417
1418    /// Helper method used to write non image/sampler types
1419    ///
1420    /// # Notes
1421    /// Adds no trailing or leading whitespace
1422    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1423        let inner = &module.types[ty].inner;
1424        match *inner {
1425            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1426            // hlsl array has the size separated from the base type
1427            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1428                self.write_type(module, base)?
1429            }
1430            ref other => self.write_value_type(module, other)?,
1431        }
1432
1433        Ok(())
1434    }
1435
1436    /// Helper method used to write value types
1437    ///
1438    /// # Notes
1439    /// Adds no trailing or leading whitespace
1440    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1441        match *inner {
1442            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1443                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1444            }
1445            TypeInner::Vector { size, scalar } => {
1446                write!(
1447                    self.out,
1448                    "{}{}",
1449                    scalar.to_hlsl_str()?,
1450                    common::vector_size_str(size)
1451                )?;
1452            }
1453            TypeInner::Matrix {
1454                columns,
1455                rows,
1456                scalar,
1457            } => {
1458                // The IR supports only float matrix
1459                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1460
1461                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1462                write!(
1463                    self.out,
1464                    "{}{}x{}",
1465                    scalar.to_hlsl_str()?,
1466                    common::vector_size_str(columns),
1467                    common::vector_size_str(rows),
1468                )?;
1469            }
1470            TypeInner::Image {
1471                dim,
1472                arrayed,
1473                class,
1474            } => {
1475                self.write_image_type(dim, arrayed, class)?;
1476            }
1477            TypeInner::Sampler { comparison } => {
1478                let sampler = if comparison {
1479                    "SamplerComparisonState"
1480                } else {
1481                    "SamplerState"
1482                };
1483                write!(self.out, "{sampler}")?;
1484            }
1485            // HLSL arrays are written as `type name[size]`
1486            // Current code is written arrays only as `[size]`
1487            // Base `type` and `name` should be written outside
1488            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1489                self.write_array_size(module, base, size)?;
1490            }
1491            TypeInner::AccelerationStructure { .. } => {
1492                write!(self.out, "RaytracingAccelerationStructure")?;
1493            }
1494            TypeInner::RayQuery { .. } => {
1495                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1496                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1497            }
1498            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1499        }
1500
1501        Ok(())
1502    }
1503
    /// Helper method used to write functions.
    ///
    /// Writes the complete HLSL definition of `func` under `name`:
    /// - when the result type is an array, first a `typedef` for it (its name
    ///   comes from the namer, seeded with `ret_<name>`), which non-entry-point
    ///   functions then use as their return type;
    /// - the signature: a modifier for an invariant `Position` result, the
    ///   return type (the generated output struct name for entry points that
    ///   have one in `self.entry_point_io`), the name, the argument list and,
    ///   for entry points, the result semantic;
    /// - the body: optional workgroup-memory zeroing, entry point argument
    ///   initialization, local variable declarations, then the statements.
    ///
    /// `self.named_expressions` is cleared once the body has been written.
    ///
    /// # Notes
    /// Ends in a newline
    fn write_function(
        &mut self,
        module: &Module,
        name: &str,
        func: &crate::Function,
        func_ctx: &back::FunctionCtx<'_>,
        info: &valid::FunctionInfo,
    ) -> BackendResult {
        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax

        // Prepare per-function expression-baking state before any statement is written.
        self.update_expressions_to_bake(module, func, info);

        if let Some(ref result) = func.result {
            // Write typedef if return type is an array.
            // Array dimensions are written after the declared identifier
            // (`write_array_size` follows the name), so an array cannot be
            // spelled in the return-type position; a typedef name is used instead.
            let array_return_type = match module.types[result.ty].inner {
                TypeInner::Array { base, size, .. } => {
                    let array_return_type = self.namer.call(&format!("ret_{name}"));
                    write!(self.out, "typedef ")?;
                    self.write_type(module, result.ty)?;
                    write!(self.out, " {array_return_type}")?;
                    self.write_array_size(module, base, size)?;
                    writeln!(self.out, ";")?;
                    Some(array_return_type)
                }
                _ => None,
            };

            // Write modifier (only for a `Position` builtin marked invariant)
            if let Some(
                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
            ) = result.binding
            {
                self.write_modifier(binding)?;
            }

            // Write return type
            match func_ctx.ty {
                back::FunctionType::Function(_) => {
                    if let Some(array_return_type) = array_return_type {
                        write!(self.out, "{array_return_type}")?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
                back::FunctionType::EntryPoint(index) => {
                    // Entry points with a generated output struct return that
                    // struct instead of the IR result type.
                    if let Some(ref ep_output) =
                        self.entry_point_io.get(&(index as usize)).unwrap().output
                    {
                        write!(self.out, "{}", ep_output.ty_name)?;
                    } else {
                        self.write_type(module, result.ty)?;
                    }
                }
            }
        } else {
            write!(self.out, "void")?;
        }

        // Write function name
        write!(self.out, " {name}(")?;

        let need_workgroup_variables_initialization =
            self.need_workgroup_variables_initialization(func_ctx, module);

        // Write function arguments: regular functions list each argument;
        // entry points either take their generated input struct or list each
        // argument with its semantic.
        match func_ctx.ty {
            back::FunctionType::Function(handle) => {
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }

                    self.write_function_argument(module, handle, arg, index)?;
                }
            }
            back::FunctionType::EntryPoint(ep_index) => {
                if let Some(ref ep_input) =
                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
                {
                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
                } else {
                    let stage = module.entry_points[ep_index as usize].stage;
                    for (index, arg) in func.arguments.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.write_type(module, arg.ty)?;

                        let argument_name =
                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                        write!(self.out, " {argument_name}")?;
                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                            self.write_array_size(module, base, size)?;
                        }

                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                    }
                }
                // The workgroup zeroing prologue needs the local invocation id,
                // so an extra `SV_GroupThreadID` parameter is appended.
                if need_workgroup_variables_initialization {
                    // A separator is needed if a parameter was already written,
                    // either the input struct or individual arguments.
                    if self
                        .entry_point_io
                        .get(&(ep_index as usize))
                        .unwrap()
                        .input
                        .is_some()
                        || !func.arguments.is_empty()
                    {
                        write!(self.out, ", ")?;
                    }
                    write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
                }
            }
        }
        // Ends of arguments
        write!(self.out, ")")?;

        // Write the result semantic if present (entry points only)
        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            let stage = module.entry_points[index as usize].stage;
            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
                self.write_semantic(binding, Some((stage, Io::Output)))?;
            }
        }

        // Function body start
        writeln!(self.out)?;
        writeln!(self.out, "{{")?;

        if need_workgroup_variables_initialization {
            self.write_workgroup_variables_initialization(func_ctx, module)?;
        }

        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
            self.write_ep_arguments_initialization(module, func, index)?;
        }

        // Write function local variables
        for (handle, local) in func.local_variables.iter() {
            // Write indentation (only for readability)
            write!(self.out, "{}", back::INDENT)?;

            // Write the local name
            // The leading space is important
            self.write_type(module, local.ty)?;
            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
            // Write size for array type
            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            match module.types[local.ty].inner {
                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
                TypeInner::RayQuery { .. } => {}
                _ => {
                    write!(self.out, " = ")?;
                    // Write the local initializer if needed
                    if let Some(init) = local.init {
                        self.write_expr(module, init, func_ctx)?;
                    } else {
                        // Zero initialize local variables
                        self.write_default_init(module, local.ty)?;
                    }
                }
            }
            // Finish the local with `;` and add a newline (only for readability)
            writeln!(self.out, ";")?
        }

        if !func.local_variables.is_empty() {
            writeln!(self.out)?;
        }

        // Write the function body (statement list)
        for sta in func.body.iter() {
            // The indentation should always be 1 when writing the function body
            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
        }

        writeln!(self.out, "}}")?;

        // Named expressions are per-function; drop them before the next function.
        self.named_expressions.clear();

        Ok(())
    }
1692
1693    fn write_function_argument(
1694        &mut self,
1695        module: &Module,
1696        handle: Handle<crate::Function>,
1697        arg: &crate::FunctionArgument,
1698        index: usize,
1699    ) -> BackendResult {
1700        // External texture arguments must be expanded into separate
1701        // arguments for each plane and the params buffer.
1702        if let TypeInner::Image {
1703            class: crate::ImageClass::External,
1704            ..
1705        } = module.types[arg.ty].inner
1706        {
1707            return self.write_function_external_texture_argument(module, handle, index);
1708        }
1709
1710        // Write argument type
1711        let arg_ty = match module.types[arg.ty].inner {
1712            // pointers in function arguments are expected and resolve to `inout`
1713            TypeInner::Pointer { base, .. } => {
1714                //TODO: can we narrow this down to just `in` when possible?
1715                write!(self.out, "inout ")?;
1716                base
1717            }
1718            _ => arg.ty,
1719        };
1720        self.write_type(module, arg_ty)?;
1721
1722        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1723
1724        // Write argument name. Space is important.
1725        write!(self.out, " {argument_name}")?;
1726        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1727            self.write_array_size(module, base, size)?;
1728        }
1729
1730        Ok(())
1731    }
1732
1733    fn write_function_external_texture_argument(
1734        &mut self,
1735        module: &Module,
1736        handle: Handle<crate::Function>,
1737        index: usize,
1738    ) -> BackendResult {
1739        let plane_names = [0, 1, 2].map(|i| {
1740            &self.names[&NameKey::ExternalTextureFunctionArgument(
1741                handle,
1742                index as u32,
1743                ExternalTextureNameKey::Plane(i),
1744            )]
1745        });
1746        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1747            handle,
1748            index as u32,
1749            ExternalTextureNameKey::Params,
1750        )];
1751        let params_ty_name =
1752            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1753        write!(
1754            self.out,
1755            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1756            plane_names[0], plane_names[1], plane_names[2],
1757        )?;
1758        Ok(())
1759    }
1760
1761    fn need_workgroup_variables_initialization(
1762        &mut self,
1763        func_ctx: &back::FunctionCtx,
1764        module: &Module,
1765    ) -> bool {
1766        self.options.zero_initialize_workgroup_memory
1767            && func_ctx.ty.is_compute_entry_point(module)
1768            && module.global_variables.iter().any(|(handle, var)| {
1769                !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1770            })
1771    }
1772
1773    fn write_workgroup_variables_initialization(
1774        &mut self,
1775        func_ctx: &back::FunctionCtx,
1776        module: &Module,
1777    ) -> BackendResult {
1778        let level = back::Level(1);
1779
1780        writeln!(
1781            self.out,
1782            "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1783        )?;
1784
1785        let vars = module.global_variables.iter().filter(|&(handle, var)| {
1786            !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1787        });
1788
1789        for (handle, var) in vars {
1790            let name = &self.names[&NameKey::GlobalVariable(handle)];
1791            write!(self.out, "{}{} = ", level.next(), name)?;
1792            self.write_default_init(module, var.ty)?;
1793            writeln!(self.out, ";")?;
1794        }
1795
1796        writeln!(self.out, "{level}}}")?;
1797        self.write_control_barrier(crate::Barrier::WORK_GROUP, level)
1798    }
1799
    /// Helper method used to write switches
    ///
    /// Writes a switch over `selector` with the given `cases` at `level`.
    ///
    /// - If every case except the last is an empty fallthrough, every selector
    ///   value reaches the same body. That body is written inside
    ///   `do { ... } while(false);` instead of a `switch`, because FXC
    ///   mishandles such switches.
    /// - Otherwise a real `switch` is written. FXC does not support falling
    ///   through from a non-empty case, so such a case is followed by
    ///   duplicated copies of the bodies of every case it would fall into.
    ///
    /// `self.continue_ctx` is entered before the switch and exited after it.
    /// If it reports a forwarded `continue` or `break`, an `if` on the flag
    /// variable that re-issues that statement is written after the switch.
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        // Write all cases
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // See docs of `back::continue_forward` module.
        // When the context hands back a flag name, declare it (initially false)
        // before the switch so statements inside can set it.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // Check if there is only one body, by seeing if all except the last case are fall through
        // with empty bodies. FXC doesn't handle these switches correctly, so
        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
        // is no need to check if one of the cases would have matched.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Start the do-while
            // (a `break` in the body leaves the do-while, just as it would leave a switch)
            writeln!(self.out, "{level}do {{")?;
            // Note: Expressions have no side-effects so we don't need to emit selector expression.

            // Body
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            // End do-while
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            // Start the switch
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                // Case label; `u32` values get a `u` suffix.
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // The new block is not only stylistic, it plays a role here:
                // We might end up having to write the same case body
                // multiple times due to FXC not supporting fallthrough.
                // Therefore, some `Expression`s written by `Statement::Emit`
                // will end up having the same name (`_expr<handle_index>`).
                // So we need to put each case in its own scope.
                // Empty fallthrough cases are bare labels that share the next case's block.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                // Although FXC does support a series of case clauses before
                // a block[^yes], it does not support fallthrough from a
                // non-empty case block to the next[^no]. If this case has a
                // non-empty body with a fallthrough, emulate that by
                // duplicating the bodies of all the cases it would fall
                // into as extensions of this case's own body. This makes
                // the HLSL output potentially quadratic in the size of the
                // Naga IR.
                //
                // [^yes]: ```hlsl
                // case 1:
                // case 2: do_stuff()
                // ```
                // [^no]: ```hlsl
                // case 1: do_this();
                // case 2: do_that();
                // ```
                if case.fall_through && !case.body.is_empty() {
                    let curr_len = i + 1;
                    // Index of the first later case that does not fall through.
                    // The `unwrap` assumes such a case exists (presumably guaranteed
                    // by validation, since the last case may not fall through).
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    // Write this case's body and each body it falls into, each in its own scope.
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Clear all named expressions that were previously inserted by the statements in the block
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    // Stop at the end of the emulated fallthrough chain unless the
                    // final body already ends with a terminator.
                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    // A non-fallthrough case needs an explicit `break` unless its
                    // body already ends with a terminator.
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Handle any forwarded continue statements.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
1944
1945    fn write_index(
1946        &mut self,
1947        module: &Module,
1948        index: Index,
1949        func_ctx: &back::FunctionCtx<'_>,
1950    ) -> BackendResult {
1951        match index {
1952            Index::Static(index) => {
1953                write!(self.out, "{index}")?;
1954            }
1955            Index::Expression(index) => {
1956                self.write_expr(module, index, func_ctx)?;
1957            }
1958        }
1959        Ok(())
1960    }
1961
1962    /// Helper method used to write statements
1963    ///
1964    /// # Notes
1965    /// Always adds a newline
1966    fn write_stmt(
1967        &mut self,
1968        module: &Module,
1969        stmt: &crate::Statement,
1970        func_ctx: &back::FunctionCtx<'_>,
1971        level: back::Level,
1972    ) -> BackendResult {
1973        use crate::Statement;
1974
1975        match *stmt {
1976            Statement::Emit(ref range) => {
1977                for handle in range.clone() {
1978                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
1979                    let expr_name = if ptr_class.is_some() {
1980                        // HLSL can't save a pointer-valued expression in a variable,
1981                        // but we shouldn't ever need to: they should never be named expressions,
1982                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
1983                        None
1984                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
1985                        // Front end provides names for all variables at the start of writing.
1986                        // But we write them to step by step. We need to recache them
1987                        // Otherwise, we could accidentally write variable name instead of full expression.
1988                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
1989                        Some(self.namer.call(name))
1990                    } else if self.need_bake_expressions.contains(&handle) {
1991                        Some(Baked(handle).to_string())
1992                    } else {
1993                        None
1994                    };
1995
1996                    if let Some(name) = expr_name {
1997                        write!(self.out, "{level}")?;
1998                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
1999                    }
2000                }
2001            }
2002            // TODO: copy-paste from glsl-out
2003            Statement::Block(ref block) => {
2004                write!(self.out, "{level}")?;
2005                writeln!(self.out, "{{")?;
2006                for sta in block.iter() {
2007                    // Increase the indentation to help with readability
2008                    self.write_stmt(module, sta, func_ctx, level.next())?
2009                }
2010                writeln!(self.out, "{level}}}")?
2011            }
2012            // TODO: copy-paste from glsl-out
2013            Statement::If {
2014                condition,
2015                ref accept,
2016                ref reject,
2017            } => {
2018                write!(self.out, "{level}")?;
2019                write!(self.out, "if (")?;
2020                self.write_expr(module, condition, func_ctx)?;
2021                writeln!(self.out, ") {{")?;
2022
2023                let l2 = level.next();
2024                for sta in accept {
2025                    // Increase indentation to help with readability
2026                    self.write_stmt(module, sta, func_ctx, l2)?;
2027                }
2028
2029                // If there are no statements in the reject block we skip writing it
2030                // This is only for readability
2031                if !reject.is_empty() {
2032                    writeln!(self.out, "{level}}} else {{")?;
2033
2034                    for sta in reject {
2035                        // Increase indentation to help with readability
2036                        self.write_stmt(module, sta, func_ctx, l2)?;
2037                    }
2038                }
2039
2040                writeln!(self.out, "{level}}}")?
2041            }
2042            // TODO: copy-paste from glsl-out
2043            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2044            Statement::Return { value: None } => {
2045                writeln!(self.out, "{level}return;")?;
2046            }
2047            Statement::Return { value: Some(expr) } => {
2048                let base_ty_res = &func_ctx.info[expr].ty;
2049                let mut resolved = base_ty_res.inner_with(&module.types);
2050                if let TypeInner::Pointer { base, space: _ } = *resolved {
2051                    resolved = &module.types[base].inner;
2052                }
2053
2054                if let TypeInner::Struct { .. } = *resolved {
2055                    // We can safely unwrap here, since we now we working with struct
2056                    let ty = base_ty_res.handle().unwrap();
2057                    let struct_name = &self.names[&NameKey::Type(ty)];
2058                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2059                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2060                    self.write_expr(module, expr, func_ctx)?;
2061                    writeln!(self.out, ";")?;
2062
2063                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2064                    let ep_output = match func_ctx.ty {
2065                        back::FunctionType::Function(_) => None,
2066                        back::FunctionType::EntryPoint(index) => self
2067                            .entry_point_io
2068                            .get(&(index as usize))
2069                            .unwrap()
2070                            .output
2071                            .as_ref(),
2072                    };
2073                    let final_name = match ep_output {
2074                        Some(ep_output) => {
2075                            let final_name = self.namer.call(&variable_name);
2076                            write!(
2077                                self.out,
2078                                "{}const {} {} = {{ ",
2079                                level, ep_output.ty_name, final_name,
2080                            )?;
2081                            for (index, m) in ep_output.members.iter().enumerate() {
2082                                if index != 0 {
2083                                    write!(self.out, ", ")?;
2084                                }
2085                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2086                                write!(self.out, "{variable_name}.{member_name}")?;
2087                            }
2088                            writeln!(self.out, " }};")?;
2089                            final_name
2090                        }
2091                        None => variable_name,
2092                    };
2093                    writeln!(self.out, "{level}return {final_name};")?;
2094                } else {
2095                    write!(self.out, "{level}return ")?;
2096                    self.write_expr(module, expr, func_ctx)?;
2097                    writeln!(self.out, ";")?
2098                }
2099            }
2100            Statement::Store { pointer, value } => {
2101                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2102                if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2103                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2104                    self.write_storage_store(
2105                        module,
2106                        var_handle,
2107                        StoreValue::Expression(value),
2108                        func_ctx,
2109                        level,
2110                        None,
2111                    )?;
2112                } else {
2113                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2114                    // See the module-level block comment in mod.rs for details.
2115                    //
2116                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2117                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2118                    enum MatrixAccess {
2119                        Direct {
2120                            base: Handle<crate::Expression>,
2121                            index: u32,
2122                        },
2123                        Struct {
2124                            columns: crate::VectorSize,
2125                            base: Handle<crate::Expression>,
2126                        },
2127                    }
2128
2129                    let get_members = |expr: Handle<crate::Expression>| {
2130                        let resolved = func_ctx.resolve_type(expr, &module.types);
2131                        match *resolved {
2132                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2133                                TypeInner::Struct { ref members, .. } => Some(members),
2134                                _ => None,
2135                            },
2136                            _ => None,
2137                        }
2138                    };
2139
2140                    write!(self.out, "{level}")?;
2141
2142                    let matrix_access_on_lhs =
2143                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2144                            |(matrix_expr, vector, scalar)| match (
2145                                func_ctx.resolve_type(matrix_expr, &module.types),
2146                                &func_ctx.expressions[matrix_expr],
2147                            ) {
2148                                (
2149                                    &TypeInner::Pointer { base: ty, .. },
2150                                    &crate::Expression::AccessIndex { base, index },
2151                                ) if matches!(
2152                                    module.types[ty].inner,
2153                                    TypeInner::Matrix {
2154                                        rows: crate::VectorSize::Bi,
2155                                        ..
2156                                    }
2157                                ) && get_members(base)
2158                                    .map(|members| members[index as usize].binding.is_none())
2159                                    == Some(true) =>
2160                                {
2161                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2162                                }
2163                                _ => {
2164                                    if let Some(MatrixType {
2165                                        columns,
2166                                        rows: crate::VectorSize::Bi,
2167                                        width: 4,
2168                                    }) = get_inner_matrix_of_struct_array_member(
2169                                        module,
2170                                        matrix_expr,
2171                                        func_ctx,
2172                                        true,
2173                                    ) {
2174                                        Some((
2175                                            MatrixAccess::Struct {
2176                                                columns,
2177                                                base: matrix_expr,
2178                                            },
2179                                            vector,
2180                                            scalar,
2181                                        ))
2182                                    } else {
2183                                        None
2184                                    }
2185                                }
2186                            },
2187                        );
2188
2189                    match matrix_access_on_lhs {
2190                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2191                            let base_ty_res = &func_ctx.info[base].ty;
2192                            let resolved = base_ty_res.inner_with(&module.types);
2193                            let ty = match *resolved {
2194                                TypeInner::Pointer { base, .. } => base,
2195                                _ => base_ty_res.handle().unwrap(),
2196                            };
2197
2198                            if let Some(Index::Static(vec_index)) = vector {
2199                                self.write_expr(module, base, func_ctx)?;
2200                                write!(
2201                                    self.out,
2202                                    ".{}_{}",
2203                                    &self.names[&NameKey::StructMember(ty, index)],
2204                                    vec_index
2205                                )?;
2206
2207                                if let Some(scalar_index) = scalar {
2208                                    write!(self.out, "[")?;
2209                                    self.write_index(module, scalar_index, func_ctx)?;
2210                                    write!(self.out, "]")?;
2211                                }
2212
2213                                write!(self.out, " = ")?;
2214                                self.write_expr(module, value, func_ctx)?;
2215                                writeln!(self.out, ";")?;
2216                            } else {
2217                                let access = WrappedStructMatrixAccess { ty, index };
2218                                match (&vector, &scalar) {
2219                                    (&Some(_), &Some(_)) => {
2220                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2221                                            access,
2222                                        )?;
2223                                    }
2224                                    (&Some(_), &None) => {
2225                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2226                                            access,
2227                                        )?;
2228                                    }
2229                                    (&None, _) => {
2230                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2231                                    }
2232                                }
2233
2234                                write!(self.out, "(")?;
2235                                self.write_expr(module, base, func_ctx)?;
2236                                write!(self.out, ", ")?;
2237                                self.write_expr(module, value, func_ctx)?;
2238
2239                                if let Some(Index::Expression(vec_index)) = vector {
2240                                    write!(self.out, ", ")?;
2241                                    self.write_expr(module, vec_index, func_ctx)?;
2242
2243                                    if let Some(scalar_index) = scalar {
2244                                        write!(self.out, ", ")?;
2245                                        self.write_index(module, scalar_index, func_ctx)?;
2246                                    }
2247                                }
2248                                writeln!(self.out, ");")?;
2249                            }
2250                        }
2251                        Some((
2252                            MatrixAccess::Struct { columns, base },
2253                            Some(Index::Expression(vec_index)),
2254                            scalar,
2255                        )) => {
2256                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2257                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2258
2259                            if scalar.is_some() {
2260                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2261                            } else {
2262                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2263                            }
2264                            write!(self.out, "(")?;
2265                            self.write_expr(module, base, func_ctx)?;
2266                            write!(self.out, ", ")?;
2267                            self.write_expr(module, vec_index, func_ctx)?;
2268
2269                            if let Some(scalar_index) = scalar {
2270                                write!(self.out, ", ")?;
2271                                self.write_index(module, scalar_index, func_ctx)?;
2272                            }
2273
2274                            write!(self.out, ", ")?;
2275                            self.write_expr(module, value, func_ctx)?;
2276
2277                            writeln!(self.out, ");")?;
2278                        }
2279                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2280                        | Some((MatrixAccess::Struct { .. }, None, _))
2281                        | None => {
2282                            self.write_expr(module, pointer, func_ctx)?;
2283                            write!(self.out, " = ")?;
2284
2285                            // We cast the RHS of this store in cases where the LHS
2286                            // is a struct member with type:
2287                            //  - matCx2 or
2288                            //  - a (possibly nested) array of matCx2's
2289                            if let Some(MatrixType {
2290                                columns,
2291                                rows: crate::VectorSize::Bi,
2292                                width: 4,
2293                            }) = get_inner_matrix_of_struct_array_member(
2294                                module, pointer, func_ctx, false,
2295                            ) {
2296                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2297                                if let TypeInner::Pointer { base, .. } = *resolved {
2298                                    resolved = &module.types[base].inner;
2299                                }
2300
2301                                write!(self.out, "(__mat{}x2", columns as u8)?;
2302                                if let TypeInner::Array { base, size, .. } = *resolved {
2303                                    self.write_array_size(module, base, size)?;
2304                                }
2305                                write!(self.out, ")")?;
2306                            }
2307
2308                            self.write_expr(module, value, func_ctx)?;
2309                            writeln!(self.out, ";")?
2310                        }
2311                    }
2312                }
2313            }
2314            Statement::Loop {
2315                ref body,
2316                ref continuing,
2317                break_if,
2318            } => {
2319                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2320                let gate_name = (!continuing.is_empty() || break_if.is_some())
2321                    .then(|| self.namer.call("loop_init"));
2322
2323                if let Some((ref decl, _)) = force_loop_bound_statements {
2324                    writeln!(self.out, "{decl}")?;
2325                }
2326                if let Some(ref gate_name) = gate_name {
2327                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2328                }
2329
2330                self.continue_ctx.enter_loop();
2331                writeln!(self.out, "{level}while(true) {{")?;
2332                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2333                    writeln!(self.out, "{break_and_inc}")?;
2334                }
2335                let l2 = level.next();
2336                if let Some(gate_name) = gate_name {
2337                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2338                    let l3 = l2.next();
2339                    for sta in continuing.iter() {
2340                        self.write_stmt(module, sta, func_ctx, l3)?;
2341                    }
2342                    if let Some(condition) = break_if {
2343                        write!(self.out, "{l3}if (")?;
2344                        self.write_expr(module, condition, func_ctx)?;
2345                        writeln!(self.out, ") {{")?;
2346                        writeln!(self.out, "{}break;", l3.next())?;
2347                        writeln!(self.out, "{l3}}}")?;
2348                    }
2349                    writeln!(self.out, "{l2}}}")?;
2350                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2351                }
2352
2353                for sta in body.iter() {
2354                    self.write_stmt(module, sta, func_ctx, l2)?;
2355                }
2356
2357                writeln!(self.out, "{level}}}")?;
2358                self.continue_ctx.exit_loop();
2359            }
2360            Statement::Break => writeln!(self.out, "{level}break;")?,
2361            Statement::Continue => {
2362                if let Some(variable) = self.continue_ctx.continue_encountered() {
2363                    writeln!(self.out, "{level}{variable} = true;")?;
2364                    writeln!(self.out, "{level}break;")?
2365                } else {
2366                    writeln!(self.out, "{level}continue;")?
2367                }
2368            }
2369            Statement::ControlBarrier(barrier) => {
2370                self.write_control_barrier(barrier, level)?;
2371            }
2372            Statement::MemoryBarrier(barrier) => {
2373                self.write_memory_barrier(barrier, level)?;
2374            }
2375            Statement::ImageStore {
2376                image,
2377                coordinate,
2378                array_index,
2379                value,
2380            } => {
2381                write!(self.out, "{level}")?;
2382                self.write_expr(module, image, func_ctx)?;
2383
2384                write!(self.out, "[")?;
2385                if let Some(index) = array_index {
2386                    // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
2387                    write!(self.out, "int3(")?;
2388                    self.write_expr(module, coordinate, func_ctx)?;
2389                    write!(self.out, ", ")?;
2390                    self.write_expr(module, index, func_ctx)?;
2391                    write!(self.out, ")")?;
2392                } else {
2393                    self.write_expr(module, coordinate, func_ctx)?;
2394                }
2395                write!(self.out, "]")?;
2396
2397                write!(self.out, " = ")?;
2398                self.write_expr(module, value, func_ctx)?;
2399                writeln!(self.out, ";")?;
2400            }
2401            Statement::Call {
2402                function,
2403                ref arguments,
2404                result,
2405            } => {
2406                write!(self.out, "{level}")?;
2407                if let Some(expr) = result {
2408                    write!(self.out, "const ")?;
2409                    let name = Baked(expr).to_string();
2410                    let expr_ty = &func_ctx.info[expr].ty;
2411                    let ty_inner = match *expr_ty {
2412                        proc::TypeResolution::Handle(handle) => {
2413                            self.write_type(module, handle)?;
2414                            &module.types[handle].inner
2415                        }
2416                        proc::TypeResolution::Value(ref value) => {
2417                            self.write_value_type(module, value)?;
2418                            value
2419                        }
2420                    };
2421                    write!(self.out, " {name}")?;
2422                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2423                        self.write_array_size(module, base, size)?;
2424                    }
2425                    write!(self.out, " = ")?;
2426                    self.named_expressions.insert(expr, name);
2427                }
2428                let func_name = &self.names[&NameKey::Function(function)];
2429                write!(self.out, "{func_name}(")?;
2430                for (index, argument) in arguments.iter().enumerate() {
2431                    if index != 0 {
2432                        write!(self.out, ", ")?;
2433                    }
2434                    self.write_expr(module, *argument, func_ctx)?;
2435                }
2436                writeln!(self.out, ");")?
2437            }
2438            Statement::Atomic {
2439                pointer,
2440                ref fun,
2441                value,
2442                result,
2443            } => {
2444                write!(self.out, "{level}")?;
2445                let res_var_info = if let Some(res_handle) = result {
2446                    let name = Baked(res_handle).to_string();
2447                    match func_ctx.info[res_handle].ty {
2448                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2449                        proc::TypeResolution::Value(ref value) => {
2450                            self.write_value_type(module, value)?
2451                        }
2452                    };
2453                    write!(self.out, " {name}; ")?;
2454                    self.named_expressions.insert(res_handle, name.clone());
2455                    Some((res_handle, name))
2456                } else {
2457                    None
2458                };
2459                let pointer_space = func_ctx
2460                    .resolve_type(pointer, &module.types)
2461                    .pointer_space()
2462                    .unwrap();
2463                let fun_str = fun.to_hlsl_suffix();
2464                let compare_expr = match *fun {
2465                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2466                    _ => None,
2467                };
2468                match pointer_space {
2469                    crate::AddressSpace::WorkGroup => {
2470                        write!(self.out, "Interlocked{fun_str}(")?;
2471                        self.write_expr(module, pointer, func_ctx)?;
2472                        self.emit_hlsl_atomic_tail(
2473                            module,
2474                            func_ctx,
2475                            fun,
2476                            compare_expr,
2477                            value,
2478                            &res_var_info,
2479                        )?;
2480                    }
2481                    crate::AddressSpace::Storage { .. } => {
2482                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2483                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2484                        let width = match func_ctx.resolve_type(value, &module.types) {
2485                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2486                            _ => "",
2487                        };
2488                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2489                        let chain = mem::take(&mut self.temp_access_chain);
2490                        self.write_storage_address(module, &chain, func_ctx)?;
2491                        self.temp_access_chain = chain;
2492                        self.emit_hlsl_atomic_tail(
2493                            module,
2494                            func_ctx,
2495                            fun,
2496                            compare_expr,
2497                            value,
2498                            &res_var_info,
2499                        )?;
2500                    }
2501                    ref other => {
2502                        return Err(Error::Custom(format!(
2503                            "invalid address space {other:?} for atomic statement"
2504                        )))
2505                    }
2506                }
2507                if let Some(cmp) = compare_expr {
2508                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2509                        write!(
2510                            self.out,
2511                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2512                        )?;
2513                        self.write_expr(module, cmp, func_ctx)?;
2514                        writeln!(self.out, ");")?;
2515                    }
2516                }
2517            }
2518            Statement::ImageAtomic {
2519                image,
2520                coordinate,
2521                array_index,
2522                fun,
2523                value,
2524            } => {
2525                write!(self.out, "{level}")?;
2526
2527                let fun_str = fun.to_hlsl_suffix();
2528                write!(self.out, "Interlocked{fun_str}(")?;
2529                self.write_expr(module, image, func_ctx)?;
2530                write!(self.out, "[")?;
2531                self.write_texture_coordinates(
2532                    "int",
2533                    coordinate,
2534                    array_index,
2535                    None,
2536                    module,
2537                    func_ctx,
2538                )?;
2539                write!(self.out, "],")?;
2540
2541                self.write_expr(module, value, func_ctx)?;
2542                writeln!(self.out, ");")?;
2543            }
2544            Statement::WorkGroupUniformLoad { pointer, result } => {
2545                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2546                write!(self.out, "{level}")?;
2547                let name = Baked(result).to_string();
2548                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2549
2550                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2551            }
2552            Statement::Switch {
2553                selector,
2554                ref cases,
2555            } => {
2556                self.write_switch(module, func_ctx, level, selector, cases)?;
2557            }
2558            Statement::RayQuery { query, ref fun } => match *fun {
2559                RayQueryFunction::Initialize {
2560                    acceleration_structure,
2561                    descriptor,
2562                } => {
2563                    write!(self.out, "{level}")?;
2564                    self.write_expr(module, query, func_ctx)?;
2565                    write!(self.out, ".TraceRayInline(")?;
2566                    self.write_expr(module, acceleration_structure, func_ctx)?;
2567                    write!(self.out, ", ")?;
2568                    self.write_expr(module, descriptor, func_ctx)?;
2569                    write!(self.out, ".flags, ")?;
2570                    self.write_expr(module, descriptor, func_ctx)?;
2571                    write!(self.out, ".cull_mask, ")?;
2572                    write!(self.out, "RayDescFromRayDesc_(")?;
2573                    self.write_expr(module, descriptor, func_ctx)?;
2574                    writeln!(self.out, "));")?;
2575                }
2576                RayQueryFunction::Proceed { result } => {
2577                    write!(self.out, "{level}")?;
2578                    let name = Baked(result).to_string();
2579                    write!(self.out, "const bool {name} = ")?;
2580                    self.named_expressions.insert(result, name);
2581                    self.write_expr(module, query, func_ctx)?;
2582                    writeln!(self.out, ".Proceed();")?;
2583                }
2584                RayQueryFunction::GenerateIntersection { hit_t } => {
2585                    write!(self.out, "{level}")?;
2586                    self.write_expr(module, query, func_ctx)?;
2587                    write!(self.out, ".CommitProceduralPrimitiveHit(")?;
2588                    self.write_expr(module, hit_t, func_ctx)?;
2589                    writeln!(self.out, ");")?;
2590                }
2591                RayQueryFunction::ConfirmIntersection => {
2592                    write!(self.out, "{level}")?;
2593                    self.write_expr(module, query, func_ctx)?;
2594                    writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?;
2595                }
2596                RayQueryFunction::Terminate => {
2597                    write!(self.out, "{level}")?;
2598                    self.write_expr(module, query, func_ctx)?;
2599                    writeln!(self.out, ".Abort();")?;
2600                }
2601            },
2602            Statement::SubgroupBallot { result, predicate } => {
2603                write!(self.out, "{level}")?;
2604                let name = Baked(result).to_string();
2605                write!(self.out, "const uint4 {name} = ")?;
2606                self.named_expressions.insert(result, name);
2607
2608                write!(self.out, "WaveActiveBallot(")?;
2609                match predicate {
2610                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2611                    None => write!(self.out, "true")?,
2612                }
2613                writeln!(self.out, ");")?;
2614            }
2615            Statement::SubgroupCollectiveOperation {
2616                op,
2617                collective_op,
2618                argument,
2619                result,
2620            } => {
2621                write!(self.out, "{level}")?;
2622                write!(self.out, "const ")?;
2623                let name = Baked(result).to_string();
2624                match func_ctx.info[result].ty {
2625                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2626                    proc::TypeResolution::Value(ref value) => {
2627                        self.write_value_type(module, value)?
2628                    }
2629                };
2630                write!(self.out, " {name} = ")?;
2631                self.named_expressions.insert(result, name);
2632
2633                match (collective_op, op) {
2634                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2635                        write!(self.out, "WaveActiveAllTrue(")?
2636                    }
2637                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2638                        write!(self.out, "WaveActiveAnyTrue(")?
2639                    }
2640                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2641                        write!(self.out, "WaveActiveSum(")?
2642                    }
2643                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2644                        write!(self.out, "WaveActiveProduct(")?
2645                    }
2646                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2647                        write!(self.out, "WaveActiveMax(")?
2648                    }
2649                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2650                        write!(self.out, "WaveActiveMin(")?
2651                    }
2652                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2653                        write!(self.out, "WaveActiveBitAnd(")?
2654                    }
2655                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2656                        write!(self.out, "WaveActiveBitOr(")?
2657                    }
2658                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2659                        write!(self.out, "WaveActiveBitXor(")?
2660                    }
2661                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2662                        write!(self.out, "WavePrefixSum(")?
2663                    }
2664                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2665                        write!(self.out, "WavePrefixProduct(")?
2666                    }
2667                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2668                        self.write_expr(module, argument, func_ctx)?;
2669                        write!(self.out, " + WavePrefixSum(")?;
2670                    }
2671                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2672                        self.write_expr(module, argument, func_ctx)?;
2673                        write!(self.out, " * WavePrefixProduct(")?;
2674                    }
2675                    _ => unimplemented!(),
2676                }
2677                self.write_expr(module, argument, func_ctx)?;
2678                writeln!(self.out, ");")?;
2679            }
2680            Statement::SubgroupGather {
2681                mode,
2682                argument,
2683                result,
2684            } => {
2685                write!(self.out, "{level}")?;
2686                write!(self.out, "const ")?;
2687                let name = Baked(result).to_string();
2688                match func_ctx.info[result].ty {
2689                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2690                    proc::TypeResolution::Value(ref value) => {
2691                        self.write_value_type(module, value)?
2692                    }
2693                };
2694                write!(self.out, " {name} = ")?;
2695                self.named_expressions.insert(result, name);
2696                match mode {
2697                    crate::GatherMode::BroadcastFirst => {
2698                        write!(self.out, "WaveReadLaneFirst(")?;
2699                        self.write_expr(module, argument, func_ctx)?;
2700                    }
2701                    crate::GatherMode::QuadBroadcast(index) => {
2702                        write!(self.out, "QuadReadLaneAt(")?;
2703                        self.write_expr(module, argument, func_ctx)?;
2704                        write!(self.out, ", ")?;
2705                        self.write_expr(module, index, func_ctx)?;
2706                    }
2707                    crate::GatherMode::QuadSwap(direction) => {
2708                        match direction {
2709                            crate::Direction::X => {
2710                                write!(self.out, "QuadReadAcrossX(")?;
2711                            }
2712                            crate::Direction::Y => {
2713                                write!(self.out, "QuadReadAcrossY(")?;
2714                            }
2715                            crate::Direction::Diagonal => {
2716                                write!(self.out, "QuadReadAcrossDiagonal(")?;
2717                            }
2718                        }
2719                        self.write_expr(module, argument, func_ctx)?;
2720                    }
2721                    _ => {
2722                        write!(self.out, "WaveReadLaneAt(")?;
2723                        self.write_expr(module, argument, func_ctx)?;
2724                        write!(self.out, ", ")?;
2725                        match mode {
2726                            crate::GatherMode::BroadcastFirst => unreachable!(),
2727                            crate::GatherMode::Broadcast(index)
2728                            | crate::GatherMode::Shuffle(index) => {
2729                                self.write_expr(module, index, func_ctx)?;
2730                            }
2731                            crate::GatherMode::ShuffleDown(index) => {
2732                                write!(self.out, "WaveGetLaneIndex() + ")?;
2733                                self.write_expr(module, index, func_ctx)?;
2734                            }
2735                            crate::GatherMode::ShuffleUp(index) => {
2736                                write!(self.out, "WaveGetLaneIndex() - ")?;
2737                                self.write_expr(module, index, func_ctx)?;
2738                            }
2739                            crate::GatherMode::ShuffleXor(index) => {
2740                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
2741                                self.write_expr(module, index, func_ctx)?;
2742                            }
2743                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2744                            crate::GatherMode::QuadSwap(_) => unreachable!(),
2745                        }
2746                    }
2747                }
2748                writeln!(self.out, ");")?;
2749            }
2750        }
2751
2752        Ok(())
2753    }
2754
2755    fn write_const_expression(
2756        &mut self,
2757        module: &Module,
2758        expr: Handle<crate::Expression>,
2759        arena: &crate::Arena<crate::Expression>,
2760    ) -> BackendResult {
2761        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
2762            writer.write_const_expression(module, expr, arena)
2763        })
2764    }
2765
2766    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2767        match literal {
2768            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2769            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2770            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2771            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2772            // `-2147483648` is parsed by some compilers as unary negation of
2773            // positive 2147483648, which is too large for an int, causing
2774            // issues for some compilers. Neither DXC nor FXC appear to have
2775            // this problem, but this is not specified and could change. We
2776            // therefore use `-2147483647 - 1` as a precaution.
2777            crate::Literal::I32(value) if value == i32::MIN => {
2778                write!(self.out, "int({} - 1)", value + 1)?
2779            }
2780            // HLSL has no suffix for explicit i32 literals, but not using any suffix
2781            // makes the type ambiguous which prevents overload resolution from
2782            // working. So we explicitly use the int() constructor syntax.
2783            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2784            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2785            // I64 version of the minimum I32 value issue described above.
2786            crate::Literal::I64(value) if value == i64::MIN => {
2787                write!(self.out, "({}L - 1L)", value + 1)?;
2788            }
2789            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2790            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2791            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2792                return Err(Error::Custom(
2793                    "Abstract types should not appear in IR presented to backends".into(),
2794                ));
2795            }
2796        }
2797        Ok(())
2798    }
2799
2800    fn write_possibly_const_expression<E>(
2801        &mut self,
2802        module: &Module,
2803        expr: Handle<crate::Expression>,
2804        expressions: &crate::Arena<crate::Expression>,
2805        write_expression: E,
2806    ) -> BackendResult
2807    where
2808        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2809    {
2810        use crate::Expression;
2811
2812        match expressions[expr] {
2813            Expression::Literal(literal) => {
2814                self.write_literal(literal)?;
2815            }
2816            Expression::Constant(handle) => {
2817                let constant = &module.constants[handle];
2818                if constant.name.is_some() {
2819                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2820                } else {
2821                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
2822                }
2823            }
2824            Expression::ZeroValue(ty) => {
2825                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2826                write!(self.out, "()")?;
2827            }
2828            Expression::Compose { ty, ref components } => {
2829                match module.types[ty].inner {
2830                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2831                        self.write_wrapped_constructor_function_name(
2832                            module,
2833                            WrappedConstructor { ty },
2834                        )?;
2835                    }
2836                    _ => {
2837                        self.write_type(module, ty)?;
2838                    }
2839                };
2840                write!(self.out, "(")?;
2841                for (index, component) in components.iter().enumerate() {
2842                    if index != 0 {
2843                        write!(self.out, ", ")?;
2844                    }
2845                    write_expression(self, *component)?;
2846                }
2847                write!(self.out, ")")?;
2848            }
2849            Expression::Splat { size, value } => {
2850                // hlsl is not supported one value constructor
2851                // if we write, for example, int4(0), dxc returns error:
2852                // error: too few elements in vector initialization (expected 4 elements, have 1)
2853                let number_of_components = match size {
2854                    crate::VectorSize::Bi => "xx",
2855                    crate::VectorSize::Tri => "xxx",
2856                    crate::VectorSize::Quad => "xxxx",
2857                };
2858                write!(self.out, "(")?;
2859                write_expression(self, value)?;
2860                write!(self.out, ").{number_of_components}")?
2861            }
2862            _ => {
2863                return Err(Error::Override);
2864            }
2865        }
2866
2867        Ok(())
2868    }
2869
2870    /// Helper method to write expressions
2871    ///
2872    /// # Notes
2873    /// Doesn't add any newlines or leading/trailing spaces
2874    pub(super) fn write_expr(
2875        &mut self,
2876        module: &Module,
2877        expr: Handle<crate::Expression>,
2878        func_ctx: &back::FunctionCtx<'_>,
2879    ) -> BackendResult {
2880        use crate::Expression;
2881
2882        // Handle the special semantics of vertex_index/instance_index
2883        let ff_input = if self.options.special_constants_binding.is_some() {
2884            func_ctx.is_fixed_function_input(expr, module)
2885        } else {
2886            None
2887        };
2888        let closing_bracket = match ff_input {
2889            Some(crate::BuiltIn::VertexIndex) => {
2890                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2891                ")"
2892            }
2893            Some(crate::BuiltIn::InstanceIndex) => {
2894                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2895                ")"
2896            }
2897            Some(crate::BuiltIn::NumWorkGroups) => {
2898                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
2899                // in compute shaders the special constants contain the number
2900                // of workgroups, which we are using here.
2901                write!(
2902                    self.out,
2903                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2904                )?;
2905                return Ok(());
2906            }
2907            _ => "",
2908        };
2909
2910        if let Some(name) = self.named_expressions.get(&expr) {
2911            write!(self.out, "{name}{closing_bracket}")?;
2912            return Ok(());
2913        }
2914
2915        let expression = &func_ctx.expressions[expr];
2916
2917        match *expression {
2918            Expression::Literal(_)
2919            | Expression::Constant(_)
2920            | Expression::ZeroValue(_)
2921            | Expression::Compose { .. }
2922            | Expression::Splat { .. } => {
2923                self.write_possibly_const_expression(
2924                    module,
2925                    expr,
2926                    func_ctx.expressions,
2927                    |writer, expr| writer.write_expr(module, expr, func_ctx),
2928                )?;
2929            }
2930            Expression::Override(_) => return Err(Error::Override),
2931            // Avoid undefined behaviour for addition, subtraction, and
2932            // multiplication of signed integers by casting operands to
2933            // unsigned, performing the operation, then casting the result back
2934            // to signed.
2935            // TODO(#7109): This relies on the asint()/asuint() functions which only work
2936            // for 32-bit types, so we must find another solution for different bit widths.
2937            Expression::Binary {
2938                op:
2939                    op @ crate::BinaryOperator::Add
2940                    | op @ crate::BinaryOperator::Subtract
2941                    | op @ crate::BinaryOperator::Multiply,
2942                left,
2943                right,
2944            } if matches!(
2945                func_ctx.resolve_type(expr, &module.types).scalar(),
2946                Some(Scalar::I32)
2947            ) =>
2948            {
2949                write!(self.out, "asint(asuint(",)?;
2950                self.write_expr(module, left, func_ctx)?;
2951                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
2952                self.write_expr(module, right, func_ctx)?;
2953                write!(self.out, "))")?;
2954            }
2955            // All of the multiplication can be expressed as `mul`,
2956            // except vector * vector, which needs to use the "*" operator.
2957            Expression::Binary {
2958                op: crate::BinaryOperator::Multiply,
2959                left,
2960                right,
2961            } if func_ctx.resolve_type(left, &module.types).is_matrix()
2962                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2963            {
2964                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
2965                write!(self.out, "mul(")?;
2966                self.write_expr(module, right, func_ctx)?;
2967                write!(self.out, ", ")?;
2968                self.write_expr(module, left, func_ctx)?;
2969                write!(self.out, ")")?;
2970            }
2971
2972            // WGSL says that floating-point division by zero should return
2973            // infinity. Microsoft's Direct3D 11 functional specification
2974            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
2975            // says:
2976            //
2977            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
2978            //
2979            // which is what we want. The DXIL specification for the FDiv
2980            // instruction corroborates this:
2981            //
2982            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
2983            Expression::Binary {
2984                op: crate::BinaryOperator::Divide,
2985                left,
2986                right,
2987            } if matches!(
2988                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2989                Some(ScalarKind::Sint | ScalarKind::Uint)
2990            ) =>
2991            {
2992                write!(self.out, "{DIV_FUNCTION}(")?;
2993                self.write_expr(module, left, func_ctx)?;
2994                write!(self.out, ", ")?;
2995                self.write_expr(module, right, func_ctx)?;
2996                write!(self.out, ")")?;
2997            }
2998
2999            Expression::Binary {
3000                op: crate::BinaryOperator::Modulo,
3001                left,
3002                right,
3003            } if matches!(
3004                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3005                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3006            ) =>
3007            {
3008                write!(self.out, "{MOD_FUNCTION}(")?;
3009                self.write_expr(module, left, func_ctx)?;
3010                write!(self.out, ", ")?;
3011                self.write_expr(module, right, func_ctx)?;
3012                write!(self.out, ")")?;
3013            }
3014
3015            Expression::Binary { op, left, right } => {
3016                write!(self.out, "(")?;
3017                self.write_expr(module, left, func_ctx)?;
3018                write!(self.out, " {} ", back::binary_operation_str(op))?;
3019                self.write_expr(module, right, func_ctx)?;
3020                write!(self.out, ")")?;
3021            }
3022            Expression::Access { base, index } => {
3023                if let Some(crate::AddressSpace::Storage { .. }) =
3024                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3025                {
3026                    // do nothing, the chain is written on `Load`/`Store`
3027                } else {
3028                    // We use the function __get_col_of_matCx2 here in cases
3029                    // where `base`s type resolves to a matCx2 and is part of a
3030                    // struct member with type of (possibly nested) array of matCx2's.
3031                    //
3032                    // Note that this only works for `Load`s and we handle
3033                    // `Store`s differently in `Statement::Store`.
3034                    if let Some(MatrixType {
3035                        columns,
3036                        rows: crate::VectorSize::Bi,
3037                        width: 4,
3038                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3039                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3040                    {
3041                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3042                        self.write_expr(module, base, func_ctx)?;
3043                        write!(self.out, ", ")?;
3044                        self.write_expr(module, index, func_ctx)?;
3045                        write!(self.out, ")")?;
3046                        return Ok(());
3047                    }
3048
3049                    let resolved = func_ctx.resolve_type(base, &module.types);
3050
3051                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3052                        TypeInner::BindingArray { .. } => {
3053                            let uniformity = &func_ctx.info[index].uniformity;
3054
3055                            (true, uniformity.non_uniform_result.is_some())
3056                        }
3057                        _ => (false, false),
3058                    };
3059
3060                    self.write_expr(module, base, func_ctx)?;
3061
3062                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3063                        module, func_ctx, base, resolved,
3064                    );
3065
3066                    if let Some(ref info) = array_sampler_info {
3067                        write!(self.out, "{}[", info.sampler_heap_name)?;
3068                    } else {
3069                        write!(self.out, "[")?;
3070                    }
3071
3072                    let needs_bound_check = self.options.restrict_indexing
3073                        && !indexing_binding_array
3074                        && match resolved.pointer_space() {
3075                            Some(
3076                                crate::AddressSpace::Function
3077                                | crate::AddressSpace::Private
3078                                | crate::AddressSpace::WorkGroup
3079                                | crate::AddressSpace::PushConstant,
3080                            )
3081                            | None => true,
3082                            Some(crate::AddressSpace::Uniform) => {
3083                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3084                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3085                                let bind_target = self
3086                                    .options
3087                                    .resolve_resource_binding(
3088                                        module.global_variables[var_handle]
3089                                            .binding
3090                                            .as_ref()
3091                                            .unwrap(),
3092                                    )
3093                                    .unwrap();
3094                                bind_target.restrict_indexing
3095                            }
3096                            Some(
3097                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3098                            ) => unreachable!(),
3099                        };
3100                    // Decide whether this index needs to be clamped to fall within range.
3101                    let restriction_needed = if needs_bound_check {
3102                        index::access_needs_check(
3103                            base,
3104                            index::GuardedIndex::Expression(index),
3105                            module,
3106                            func_ctx.expressions,
3107                            func_ctx.info,
3108                        )
3109                    } else {
3110                        None
3111                    };
3112                    if let Some(limit) = restriction_needed {
3113                        write!(self.out, "min(uint(")?;
3114                        self.write_expr(module, index, func_ctx)?;
3115                        write!(self.out, "), ")?;
3116                        match limit {
3117                            index::IndexableLength::Known(limit) => {
3118                                write!(self.out, "{}u", limit - 1)?;
3119                            }
3120                            index::IndexableLength::Dynamic => unreachable!(),
3121                        }
3122                        write!(self.out, ")")?;
3123                    } else {
3124                        if non_uniform_qualifier {
3125                            write!(self.out, "NonUniformResourceIndex(")?;
3126                        }
3127                        if let Some(ref info) = array_sampler_info {
3128                            write!(
3129                                self.out,
3130                                "{}[{} + ",
3131                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3132                            )?;
3133                        }
3134                        self.write_expr(module, index, func_ctx)?;
3135                        if array_sampler_info.is_some() {
3136                            write!(self.out, "]")?;
3137                        }
3138                        if non_uniform_qualifier {
3139                            write!(self.out, ")")?;
3140                        }
3141                    }
3142
3143                    write!(self.out, "]")?;
3144                }
3145            }
3146            Expression::AccessIndex { base, index } => {
3147                if let Some(crate::AddressSpace::Storage { .. }) =
3148                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3149                {
3150                    // do nothing, the chain is written on `Load`/`Store`
3151                } else {
3152                    // See if we need to write the matrix column access in a
3153                    // special way since the type of `base` is our special
3154                    // __matCx2 struct.
3155                    if let Some(MatrixType {
3156                        rows: crate::VectorSize::Bi,
3157                        width: 4,
3158                        ..
3159                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3160                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3161                    {
3162                        self.write_expr(module, base, func_ctx)?;
3163                        write!(self.out, "._{index}")?;
3164                        return Ok(());
3165                    }
3166
3167                    let base_ty_res = &func_ctx.info[base].ty;
3168                    let mut resolved = base_ty_res.inner_with(&module.types);
3169                    let base_ty_handle = match *resolved {
3170                        TypeInner::Pointer { base, .. } => {
3171                            resolved = &module.types[base].inner;
3172                            Some(base)
3173                        }
3174                        _ => base_ty_res.handle(),
3175                    };
3176
3177                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3178                    // See the module-level block comment in mod.rs for details.
3179                    //
3180                    // We handle matrix reconstruction here for Loads.
3181                    // Stores are handled directly by `Statement::Store`.
3182                    if let TypeInner::Struct { ref members, .. } = *resolved {
3183                        let member = &members[index as usize];
3184
3185                        match module.types[member.ty].inner {
3186                            TypeInner::Matrix {
3187                                rows: crate::VectorSize::Bi,
3188                                ..
3189                            } if member.binding.is_none() => {
3190                                let ty = base_ty_handle.unwrap();
3191                                self.write_wrapped_struct_matrix_get_function_name(
3192                                    WrappedStructMatrixAccess { ty, index },
3193                                )?;
3194                                write!(self.out, "(")?;
3195                                self.write_expr(module, base, func_ctx)?;
3196                                write!(self.out, ")")?;
3197                                return Ok(());
3198                            }
3199                            _ => {}
3200                        }
3201                    }
3202
3203                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3204                        module, func_ctx, base, resolved,
3205                    );
3206
3207                    if let Some(ref info) = array_sampler_info {
3208                        write!(
3209                            self.out,
3210                            "{}[{}",
3211                            info.sampler_heap_name, info.sampler_index_buffer_name
3212                        )?;
3213                    }
3214
3215                    self.write_expr(module, base, func_ctx)?;
3216
3217                    match *resolved {
3218                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3219                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3220                        // it and generates completely nonsensical DXBC.
3221                        //
3222                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3223                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3224                            // Write vector access as a swizzle
3225                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3226                        }
3227                        TypeInner::Matrix { .. }
3228                        | TypeInner::Array { .. }
3229                        | TypeInner::BindingArray { .. } => {
3230                            if let Some(ref info) = array_sampler_info {
3231                                write!(
3232                                    self.out,
3233                                    "[{} + {index}]",
3234                                    info.binding_array_base_index_name
3235                                )?;
3236                            } else {
3237                                write!(self.out, "[{index}]")?;
3238                            }
3239                        }
3240                        TypeInner::Struct { .. } => {
3241                            // This will never panic in case the type is a `Struct`, this is not true
3242                            // for other types so we can only check while inside this match arm
3243                            let ty = base_ty_handle.unwrap();
3244
3245                            write!(
3246                                self.out,
3247                                ".{}",
3248                                &self.names[&NameKey::StructMember(ty, index)]
3249                            )?
3250                        }
3251                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3252                    }
3253
3254                    if array_sampler_info.is_some() {
3255                        write!(self.out, "]")?;
3256                    }
3257                }
3258            }
3259            Expression::FunctionArgument(pos) => {
3260                let ty = func_ctx.resolve_type(expr, &module.types);
3261
3262                // We know that any external texture function argument has been expanded into
3263                // separate consecutive arguments for each plane and the parameters buffer. And we
3264                // also know that external textures can only ever be used as an argument to another
3265                // function. Therefore we can simply emit each of the expanded arguments in a
3266                // consecutive comma-separated list.
3267                if let TypeInner::Image {
3268                    class: crate::ImageClass::External,
3269                    ..
3270                } = *ty
3271                {
3272                    let plane_names = [0, 1, 2].map(|i| {
3273                        &self.names[&func_ctx
3274                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3275                    });
3276                    let params_name = &self.names[&func_ctx
3277                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3278                    write!(
3279                        self.out,
3280                        "{}, {}, {}, {}",
3281                        plane_names[0], plane_names[1], plane_names[2], params_name
3282                    )?;
3283                } else {
3284                    let key = func_ctx.argument_key(pos);
3285                    let name = &self.names[&key];
3286                    write!(self.out, "{name}")?;
3287                }
3288            }
3289            Expression::ImageSample {
3290                coordinate,
3291                image,
3292                sampler,
3293                clamp_to_edge: true,
3294                gather: None,
3295                array_index: None,
3296                offset: None,
3297                level: crate::SampleLevel::Zero,
3298                depth_ref: None,
3299            } => {
3300                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3301                self.write_expr(module, image, func_ctx)?;
3302                write!(self.out, ", ")?;
3303                self.write_expr(module, sampler, func_ctx)?;
3304                write!(self.out, ", ")?;
3305                self.write_expr(module, coordinate, func_ctx)?;
3306                write!(self.out, ")")?;
3307            }
3308            Expression::ImageSample {
3309                image,
3310                sampler,
3311                gather,
3312                coordinate,
3313                array_index,
3314                offset,
3315                level,
3316                depth_ref,
3317                clamp_to_edge,
3318            } => {
3319                if clamp_to_edge {
3320                    return Err(Error::Custom(
3321                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3322                    ));
3323                }
3324
3325                use crate::SampleLevel as Sl;
3326                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3327
3328                let (base_str, component_str) = match gather {
3329                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3330                    None => ("Sample", ""),
3331                };
3332                let cmp_str = match depth_ref {
3333                    Some(_) => "Cmp",
3334                    None => "",
3335                };
3336                let level_str = match level {
3337                    Sl::Zero if gather.is_none() => "LevelZero",
3338                    Sl::Auto | Sl::Zero => "",
3339                    Sl::Exact(_) => "Level",
3340                    Sl::Bias(_) => "Bias",
3341                    Sl::Gradient { .. } => "Grad",
3342                };
3343
3344                self.write_expr(module, image, func_ctx)?;
3345                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3346                self.write_expr(module, sampler, func_ctx)?;
3347                write!(self.out, ", ")?;
3348                self.write_texture_coordinates(
3349                    "float",
3350                    coordinate,
3351                    array_index,
3352                    None,
3353                    module,
3354                    func_ctx,
3355                )?;
3356
3357                if let Some(depth_ref) = depth_ref {
3358                    write!(self.out, ", ")?;
3359                    self.write_expr(module, depth_ref, func_ctx)?;
3360                }
3361
3362                match level {
3363                    Sl::Auto | Sl::Zero => {}
3364                    Sl::Exact(expr) => {
3365                        write!(self.out, ", ")?;
3366                        self.write_expr(module, expr, func_ctx)?;
3367                    }
3368                    Sl::Bias(expr) => {
3369                        write!(self.out, ", ")?;
3370                        self.write_expr(module, expr, func_ctx)?;
3371                    }
3372                    Sl::Gradient { x, y } => {
3373                        write!(self.out, ", ")?;
3374                        self.write_expr(module, x, func_ctx)?;
3375                        write!(self.out, ", ")?;
3376                        self.write_expr(module, y, func_ctx)?;
3377                    }
3378                }
3379
3380                if let Some(offset) = offset {
3381                    write!(self.out, ", ")?;
3382                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3383                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3384                    write!(self.out, ")")?;
3385                }
3386
3387                write!(self.out, ")")?;
3388            }
3389            Expression::ImageQuery { image, query } => {
3390                // use wrapped image query function
3391                if let TypeInner::Image {
3392                    dim,
3393                    arrayed,
3394                    class,
3395                } = *func_ctx.resolve_type(image, &module.types)
3396                {
3397                    let wrapped_image_query = WrappedImageQuery {
3398                        dim,
3399                        arrayed,
3400                        class,
3401                        query: query.into(),
3402                    };
3403
3404                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3405                    write!(self.out, "(")?;
3406                    // Image always first param
3407                    self.write_expr(module, image, func_ctx)?;
3408                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3409                        write!(self.out, ", ")?;
3410                        self.write_expr(module, level, func_ctx)?;
3411                    }
3412                    write!(self.out, ")")?;
3413                }
3414            }
3415            Expression::ImageLoad {
3416                image,
3417                coordinate,
3418                array_index,
3419                sample,
3420                level,
3421            } => self.write_image_load(
3422                &module,
3423                expr,
3424                func_ctx,
3425                image,
3426                coordinate,
3427                array_index,
3428                sample,
3429                level,
3430            )?,
3431            Expression::GlobalVariable(handle) => {
3432                let global_variable = &module.global_variables[handle];
3433                let ty = &module.types[global_variable.ty].inner;
3434
3435                // In the case of binding arrays of samplers, we need to not write anything
3436                // as the we are in the wrong position to fully write the expression.
3437                //
3438                // The entire writing is done by AccessIndex.
3439                let is_binding_array_of_samplers = match *ty {
3440                    TypeInner::BindingArray { base, .. } => {
3441                        let base_ty = &module.types[base].inner;
3442                        matches!(*base_ty, TypeInner::Sampler { .. })
3443                    }
3444                    _ => false,
3445                };
3446
3447                let is_storage_space =
3448                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3449
3450                // Our external texture global variable has been expanded into multiple
3451                // global variables, one for each plane and the parameters buffer.
3452                // External textures can only ever be used as arguments to a function
3453                // call, and we know that an external texture argument to any function
3454                // will have been expanded to separate consecutive arguments for each
3455                // plane and the parameters buffer. Therefore we can simply emit each of
3456                // the expanded global variables in a consecutive comma-separated list.
3457                if let TypeInner::Image {
3458                    class: crate::ImageClass::External,
3459                    ..
3460                } = *ty
3461                {
3462                    let plane_names = [0, 1, 2].map(|i| {
3463                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3464                            handle,
3465                            ExternalTextureNameKey::Plane(i),
3466                        )]
3467                    });
3468                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3469                        handle,
3470                        ExternalTextureNameKey::Params,
3471                    )];
3472                    write!(
3473                        self.out,
3474                        "{}, {}, {}, {}",
3475                        plane_names[0], plane_names[1], plane_names[2], params_name
3476                    )?;
3477                } else if !is_binding_array_of_samplers && !is_storage_space {
3478                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3479                    write!(self.out, "{name}")?;
3480                }
3481            }
3482            Expression::LocalVariable(handle) => {
3483                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3484            }
3485            Expression::Load { pointer } => {
3486                match func_ctx
3487                    .resolve_type(pointer, &module.types)
3488                    .pointer_space()
3489                {
3490                    Some(crate::AddressSpace::Storage { .. }) => {
3491                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3492                        let result_ty = func_ctx.info[expr].ty.clone();
3493                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3494                    }
3495                    _ => {
3496                        let mut close_paren = false;
3497
3498                        // We cast the value loaded to a native HLSL floatCx2
3499                        // in cases where it is of type:
3500                        //  - __matCx2 or
3501                        //  - a (possibly nested) array of __matCx2's
3502                        if let Some(MatrixType {
3503                            rows: crate::VectorSize::Bi,
3504                            width: 4,
3505                            ..
3506                        }) = get_inner_matrix_of_struct_array_member(
3507                            module, pointer, func_ctx, false,
3508                        )
3509                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3510                        {
3511                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3512                            let ptr_tr = resolved.pointer_base_type();
3513                            if let Some(ptr_ty) =
3514                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3515                            {
3516                                resolved = ptr_ty;
3517                            }
3518
3519                            write!(self.out, "((")?;
3520                            if let TypeInner::Array { base, size, .. } = *resolved {
3521                                self.write_type(module, base)?;
3522                                self.write_array_size(module, base, size)?;
3523                            } else {
3524                                self.write_value_type(module, resolved)?;
3525                            }
3526                            write!(self.out, ")")?;
3527                            close_paren = true;
3528                        }
3529
3530                        self.write_expr(module, pointer, func_ctx)?;
3531
3532                        if close_paren {
3533                            write!(self.out, ")")?;
3534                        }
3535                    }
3536                }
3537            }
3538            Expression::Unary { op, expr } => {
3539                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3540                let op_str = match op {
3541                    crate::UnaryOperator::Negate => {
3542                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3543                            Some(Scalar::I32) => NEG_FUNCTION,
3544                            _ => "-",
3545                        }
3546                    }
3547                    crate::UnaryOperator::LogicalNot => "!",
3548                    crate::UnaryOperator::BitwiseNot => "~",
3549                };
3550                write!(self.out, "{op_str}(")?;
3551                self.write_expr(module, expr, func_ctx)?;
3552                write!(self.out, ")")?;
3553            }
3554            Expression::As {
3555                expr,
3556                kind,
3557                convert,
3558            } => {
3559                let inner = func_ctx.resolve_type(expr, &module.types);
3560                if inner.scalar_kind() == Some(ScalarKind::Float)
3561                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3562                    && convert.is_some()
3563                {
3564                    // Use helper functions for float to int casts in order to
3565                    // avoid undefined behaviour when value is out of range for
3566                    // the target type.
3567                    let fun_name = match (kind, convert) {
3568                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3569                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3570                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3571                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3572                        _ => unreachable!(),
3573                    };
3574                    write!(self.out, "{fun_name}(")?;
3575                    self.write_expr(module, expr, func_ctx)?;
3576                    write!(self.out, ")")?;
3577                } else {
3578                    let close_paren = match convert {
3579                        Some(dst_width) => {
3580                            let scalar = Scalar {
3581                                kind,
3582                                width: dst_width,
3583                            };
3584                            match *inner {
3585                                TypeInner::Vector { size, .. } => {
3586                                    write!(
3587                                        self.out,
3588                                        "{}{}(",
3589                                        scalar.to_hlsl_str()?,
3590                                        common::vector_size_str(size)
3591                                    )?;
3592                                }
3593                                TypeInner::Scalar(_) => {
3594                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3595                                }
3596                                TypeInner::Matrix { columns, rows, .. } => {
3597                                    write!(
3598                                        self.out,
3599                                        "{}{}x{}(",
3600                                        scalar.to_hlsl_str()?,
3601                                        common::vector_size_str(columns),
3602                                        common::vector_size_str(rows)
3603                                    )?;
3604                                }
3605                                _ => {
3606                                    return Err(Error::Unimplemented(format!(
3607                                        "write_expr expression::as {inner:?}"
3608                                    )));
3609                                }
3610                            };
3611                            true
3612                        }
3613                        None => {
3614                            if inner.scalar_width() == Some(8) {
3615                                false
3616                            } else {
3617                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3618                                true
3619                            }
3620                        }
3621                    };
3622                    self.write_expr(module, expr, func_ctx)?;
3623                    if close_paren {
3624                        write!(self.out, ")")?;
3625                    }
3626                }
3627            }
3628            Expression::Math {
3629                fun,
3630                arg,
3631                arg1,
3632                arg2,
3633                arg3,
3634            } => {
3635                use crate::MathFunction as Mf;
3636
3637                enum Function {
3638                    Asincosh { is_sin: bool },
3639                    Atanh,
3640                    Pack2x16float,
3641                    Pack2x16snorm,
3642                    Pack2x16unorm,
3643                    Pack4x8snorm,
3644                    Pack4x8unorm,
3645                    Pack4xI8,
3646                    Pack4xU8,
3647                    Pack4xI8Clamp,
3648                    Pack4xU8Clamp,
3649                    Unpack2x16float,
3650                    Unpack2x16snorm,
3651                    Unpack2x16unorm,
3652                    Unpack4x8snorm,
3653                    Unpack4x8unorm,
3654                    Unpack4xI8,
3655                    Unpack4xU8,
3656                    Dot4I8Packed,
3657                    Dot4U8Packed,
3658                    QuantizeToF16,
3659                    Regular(&'static str),
3660                    MissingIntOverload(&'static str),
3661                    MissingIntReturnType(&'static str),
3662                    CountTrailingZeros,
3663                    CountLeadingZeros,
3664                }
3665
3666                let fun = match fun {
3667                    // comparison
3668                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3669                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3670                        _ => Function::Regular("abs"),
3671                    },
3672                    Mf::Min => Function::Regular("min"),
3673                    Mf::Max => Function::Regular("max"),
3674                    Mf::Clamp => Function::Regular("clamp"),
3675                    Mf::Saturate => Function::Regular("saturate"),
3676                    // trigonometry
3677                    Mf::Cos => Function::Regular("cos"),
3678                    Mf::Cosh => Function::Regular("cosh"),
3679                    Mf::Sin => Function::Regular("sin"),
3680                    Mf::Sinh => Function::Regular("sinh"),
3681                    Mf::Tan => Function::Regular("tan"),
3682                    Mf::Tanh => Function::Regular("tanh"),
3683                    Mf::Acos => Function::Regular("acos"),
3684                    Mf::Asin => Function::Regular("asin"),
3685                    Mf::Atan => Function::Regular("atan"),
3686                    Mf::Atan2 => Function::Regular("atan2"),
3687                    Mf::Asinh => Function::Asincosh { is_sin: true },
3688                    Mf::Acosh => Function::Asincosh { is_sin: false },
3689                    Mf::Atanh => Function::Atanh,
3690                    Mf::Radians => Function::Regular("radians"),
3691                    Mf::Degrees => Function::Regular("degrees"),
3692                    // decomposition
3693                    Mf::Ceil => Function::Regular("ceil"),
3694                    Mf::Floor => Function::Regular("floor"),
3695                    Mf::Round => Function::Regular("round"),
3696                    Mf::Fract => Function::Regular("frac"),
3697                    Mf::Trunc => Function::Regular("trunc"),
3698                    Mf::Modf => Function::Regular(MODF_FUNCTION),
3699                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3700                    Mf::Ldexp => Function::Regular("ldexp"),
3701                    // exponent
3702                    Mf::Exp => Function::Regular("exp"),
3703                    Mf::Exp2 => Function::Regular("exp2"),
3704                    Mf::Log => Function::Regular("log"),
3705                    Mf::Log2 => Function::Regular("log2"),
3706                    Mf::Pow => Function::Regular("pow"),
3707                    // geometry
3708                    Mf::Dot => Function::Regular("dot"),
3709                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
3710                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
3711                    //Mf::Outer => ,
3712                    Mf::Cross => Function::Regular("cross"),
3713                    Mf::Distance => Function::Regular("distance"),
3714                    Mf::Length => Function::Regular("length"),
3715                    Mf::Normalize => Function::Regular("normalize"),
3716                    Mf::FaceForward => Function::Regular("faceforward"),
3717                    Mf::Reflect => Function::Regular("reflect"),
3718                    Mf::Refract => Function::Regular("refract"),
3719                    // computational
3720                    Mf::Sign => Function::Regular("sign"),
3721                    Mf::Fma => Function::Regular("mad"),
3722                    Mf::Mix => Function::Regular("lerp"),
3723                    Mf::Step => Function::Regular("step"),
3724                    Mf::SmoothStep => Function::Regular("smoothstep"),
3725                    Mf::Sqrt => Function::Regular("sqrt"),
3726                    Mf::InverseSqrt => Function::Regular("rsqrt"),
3727                    //Mf::Inverse =>,
3728                    Mf::Transpose => Function::Regular("transpose"),
3729                    Mf::Determinant => Function::Regular("determinant"),
3730                    Mf::QuantizeToF16 => Function::QuantizeToF16,
3731                    // bits
3732                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
3733                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
3734                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3735                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3736                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3737                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3738                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3739                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3740                    // Data Packing
3741                    Mf::Pack2x16float => Function::Pack2x16float,
3742                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
3743                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
3744                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
3745                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
3746                    Mf::Pack4xI8 => Function::Pack4xI8,
3747                    Mf::Pack4xU8 => Function::Pack4xU8,
3748                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3749                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3750                    // Data Unpacking
3751                    Mf::Unpack2x16float => Function::Unpack2x16float,
3752                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3753                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3754                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3755                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3756                    Mf::Unpack4xI8 => Function::Unpack4xI8,
3757                    Mf::Unpack4xU8 => Function::Unpack4xU8,
3758                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3759                };
3760
3761                match fun {
3762                    Function::Asincosh { is_sin } => {
3763                        write!(self.out, "log(")?;
3764                        self.write_expr(module, arg, func_ctx)?;
3765                        write!(self.out, " + sqrt(")?;
3766                        self.write_expr(module, arg, func_ctx)?;
3767                        write!(self.out, " * ")?;
3768                        self.write_expr(module, arg, func_ctx)?;
3769                        match is_sin {
3770                            true => write!(self.out, " + 1.0))")?,
3771                            false => write!(self.out, " - 1.0))")?,
3772                        }
3773                    }
3774                    Function::Atanh => {
3775                        write!(self.out, "0.5 * log((1.0 + ")?;
3776                        self.write_expr(module, arg, func_ctx)?;
3777                        write!(self.out, ") / (1.0 - ")?;
3778                        self.write_expr(module, arg, func_ctx)?;
3779                        write!(self.out, "))")?;
3780                    }
3781                    Function::Pack2x16float => {
3782                        write!(self.out, "(f32tof16(")?;
3783                        self.write_expr(module, arg, func_ctx)?;
3784                        write!(self.out, "[0]) | f32tof16(")?;
3785                        self.write_expr(module, arg, func_ctx)?;
3786                        write!(self.out, "[1]) << 16)")?;
3787                    }
3788                    Function::Pack2x16snorm => {
3789                        let scale = 32767;
3790
3791                        write!(self.out, "uint((int(round(clamp(")?;
3792                        self.write_expr(module, arg, func_ctx)?;
3793                        write!(
3794                            self.out,
3795                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3796                        )?;
3797                        self.write_expr(module, arg, func_ctx)?;
3798                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3799                    }
3800                    Function::Pack2x16unorm => {
3801                        let scale = 65535;
3802
3803                        write!(self.out, "(uint(round(clamp(")?;
3804                        self.write_expr(module, arg, func_ctx)?;
3805                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3806                        self.write_expr(module, arg, func_ctx)?;
3807                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3808                    }
3809                    Function::Pack4x8snorm => {
3810                        let scale = 127;
3811
3812                        write!(self.out, "uint((int(round(clamp(")?;
3813                        self.write_expr(module, arg, func_ctx)?;
3814                        write!(
3815                            self.out,
3816                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3817                        )?;
3818                        self.write_expr(module, arg, func_ctx)?;
3819                        write!(
3820                            self.out,
3821                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3822                        )?;
3823                        self.write_expr(module, arg, func_ctx)?;
3824                        write!(
3825                            self.out,
3826                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3827                        )?;
3828                        self.write_expr(module, arg, func_ctx)?;
3829                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3830                    }
3831                    Function::Pack4x8unorm => {
3832                        let scale = 255;
3833
3834                        write!(self.out, "(uint(round(clamp(")?;
3835                        self.write_expr(module, arg, func_ctx)?;
3836                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3837                        self.write_expr(module, arg, func_ctx)?;
3838                        write!(
3839                            self.out,
3840                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3841                        )?;
3842                        self.write_expr(module, arg, func_ctx)?;
3843                        write!(
3844                            self.out,
3845                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3846                        )?;
3847                        self.write_expr(module, arg, func_ctx)?;
3848                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3849                    }
3850                    fun @ (Function::Pack4xI8
3851                    | Function::Pack4xU8
3852                    | Function::Pack4xI8Clamp
3853                    | Function::Pack4xU8Clamp) => {
3854                        let was_signed =
3855                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3856                        let clamp_bounds = match fun {
3857                            Function::Pack4xI8Clamp => Some(("-128", "127")),
3858                            Function::Pack4xU8Clamp => Some(("0", "255")),
3859                            _ => None,
3860                        };
3861                        if was_signed {
3862                            write!(self.out, "uint(")?;
3863                        }
3864                        let write_arg = |this: &mut Self| -> BackendResult {
3865                            if let Some((min, max)) = clamp_bounds {
3866                                write!(this.out, "clamp(")?;
3867                                this.write_expr(module, arg, func_ctx)?;
3868                                write!(this.out, ", {min}, {max})")?;
3869                            } else {
3870                                this.write_expr(module, arg, func_ctx)?;
3871                            }
3872                            Ok(())
3873                        };
3874                        write!(self.out, "(")?;
3875                        write_arg(self)?;
3876                        write!(self.out, "[0] & 0xFF) | ((")?;
3877                        write_arg(self)?;
3878                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3879                        write_arg(self)?;
3880                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3881                        write_arg(self)?;
3882                        write!(self.out, "[3] & 0xFF) << 24)")?;
3883                        if was_signed {
3884                            write!(self.out, ")")?;
3885                        }
3886                    }
3887
3888                    Function::Unpack2x16float => {
3889                        write!(self.out, "float2(f16tof32(")?;
3890                        self.write_expr(module, arg, func_ctx)?;
3891                        write!(self.out, "), f16tof32((")?;
3892                        self.write_expr(module, arg, func_ctx)?;
3893                        write!(self.out, ") >> 16))")?;
3894                    }
3895                    Function::Unpack2x16snorm => {
3896                        let scale = 32767;
3897
3898                        write!(self.out, "(float2(int2(")?;
3899                        self.write_expr(module, arg, func_ctx)?;
3900                        write!(self.out, " << 16, ")?;
3901                        self.write_expr(module, arg, func_ctx)?;
3902                        write!(self.out, ") >> 16) / {scale}.0)")?;
3903                    }
3904                    Function::Unpack2x16unorm => {
3905                        let scale = 65535;
3906
3907                        write!(self.out, "(float2(")?;
3908                        self.write_expr(module, arg, func_ctx)?;
3909                        write!(self.out, " & 0xFFFF, ")?;
3910                        self.write_expr(module, arg, func_ctx)?;
3911                        write!(self.out, " >> 16) / {scale}.0)")?;
3912                    }
3913                    Function::Unpack4x8snorm => {
3914                        let scale = 127;
3915
3916                        write!(self.out, "(float4(int4(")?;
3917                        self.write_expr(module, arg, func_ctx)?;
3918                        write!(self.out, " << 24, ")?;
3919                        self.write_expr(module, arg, func_ctx)?;
3920                        write!(self.out, " << 16, ")?;
3921                        self.write_expr(module, arg, func_ctx)?;
3922                        write!(self.out, " << 8, ")?;
3923                        self.write_expr(module, arg, func_ctx)?;
3924                        write!(self.out, ") >> 24) / {scale}.0)")?;
3925                    }
3926                    Function::Unpack4x8unorm => {
3927                        let scale = 255;
3928
3929                        write!(self.out, "(float4(")?;
3930                        self.write_expr(module, arg, func_ctx)?;
3931                        write!(self.out, " & 0xFF, ")?;
3932                        self.write_expr(module, arg, func_ctx)?;
3933                        write!(self.out, " >> 8 & 0xFF, ")?;
3934                        self.write_expr(module, arg, func_ctx)?;
3935                        write!(self.out, " >> 16 & 0xFF, ")?;
3936                        self.write_expr(module, arg, func_ctx)?;
3937                        write!(self.out, " >> 24) / {scale}.0)")?;
3938                    }
3939                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3940                        write!(self.out, "(")?;
3941                        if matches!(fun, Function::Unpack4xU8) {
3942                            write!(self.out, "u")?;
3943                        }
3944                        write!(self.out, "int4(")?;
3945                        self.write_expr(module, arg, func_ctx)?;
3946                        write!(self.out, ", ")?;
3947                        self.write_expr(module, arg, func_ctx)?;
3948                        write!(self.out, " >> 8, ")?;
3949                        self.write_expr(module, arg, func_ctx)?;
3950                        write!(self.out, " >> 16, ")?;
3951                        self.write_expr(module, arg, func_ctx)?;
3952                        write!(self.out, " >> 24) << 24 >> 24)")?;
3953                    }
3954                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3955                        let arg1 = arg1.unwrap();
3956
3957                        if self.options.shader_model >= ShaderModel::V6_4 {
3958                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
3959                            let function_name = match fun {
3960                                Function::Dot4I8Packed => "dot4add_i8packed",
3961                                Function::Dot4U8Packed => "dot4add_u8packed",
3962                                _ => unreachable!(),
3963                            };
3964                            write!(self.out, "{function_name}(")?;
3965                            self.write_expr(module, arg, func_ctx)?;
3966                            write!(self.out, ", ")?;
3967                            self.write_expr(module, arg1, func_ctx)?;
3968                            write!(self.out, ", 0)")?;
3969                        } else {
3970                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
3971                            write!(self.out, "dot(")?;
3972
3973                            if matches!(fun, Function::Dot4U8Packed) {
3974                                write!(self.out, "u")?;
3975                            }
3976                            write!(self.out, "int4(")?;
3977                            self.write_expr(module, arg, func_ctx)?;
3978                            write!(self.out, ", ")?;
3979                            self.write_expr(module, arg, func_ctx)?;
3980                            write!(self.out, " >> 8, ")?;
3981                            self.write_expr(module, arg, func_ctx)?;
3982                            write!(self.out, " >> 16, ")?;
3983                            self.write_expr(module, arg, func_ctx)?;
3984                            write!(self.out, " >> 24) << 24 >> 24, ")?;
3985
3986                            if matches!(fun, Function::Dot4U8Packed) {
3987                                write!(self.out, "u")?;
3988                            }
3989                            write!(self.out, "int4(")?;
3990                            self.write_expr(module, arg1, func_ctx)?;
3991                            write!(self.out, ", ")?;
3992                            self.write_expr(module, arg1, func_ctx)?;
3993                            write!(self.out, " >> 8, ")?;
3994                            self.write_expr(module, arg1, func_ctx)?;
3995                            write!(self.out, " >> 16, ")?;
3996                            self.write_expr(module, arg1, func_ctx)?;
3997                            write!(self.out, " >> 24) << 24 >> 24)")?;
3998                        }
3999                    }
4000                    Function::QuantizeToF16 => {
4001                        write!(self.out, "f16tof32(f32tof16(")?;
4002                        self.write_expr(module, arg, func_ctx)?;
4003                        write!(self.out, "))")?;
4004                    }
4005                    Function::Regular(fun_name) => {
4006                        write!(self.out, "{fun_name}(")?;
4007                        self.write_expr(module, arg, func_ctx)?;
4008                        if let Some(arg) = arg1 {
4009                            write!(self.out, ", ")?;
4010                            self.write_expr(module, arg, func_ctx)?;
4011                        }
4012                        if let Some(arg) = arg2 {
4013                            write!(self.out, ", ")?;
4014                            self.write_expr(module, arg, func_ctx)?;
4015                        }
4016                        if let Some(arg) = arg3 {
4017                            write!(self.out, ", ")?;
4018                            self.write_expr(module, arg, func_ctx)?;
4019                        }
4020                        write!(self.out, ")")?
4021                    }
4022                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4023                    // as non-32bit types are DXC only.
4024                    Function::MissingIntOverload(fun_name) => {
4025                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4026                        if let Some(Scalar::I32) = scalar_kind {
4027                            write!(self.out, "asint({fun_name}(asuint(")?;
4028                            self.write_expr(module, arg, func_ctx)?;
4029                            write!(self.out, ")))")?;
4030                        } else {
4031                            write!(self.out, "{fun_name}(")?;
4032                            self.write_expr(module, arg, func_ctx)?;
4033                            write!(self.out, ")")?;
4034                        }
4035                    }
4036                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4037                    // as non-32bit types are DXC only.
4038                    Function::MissingIntReturnType(fun_name) => {
4039                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4040                        if let Some(Scalar::I32) = scalar_kind {
4041                            write!(self.out, "asint({fun_name}(")?;
4042                            self.write_expr(module, arg, func_ctx)?;
4043                            write!(self.out, "))")?;
4044                        } else {
4045                            write!(self.out, "{fun_name}(")?;
4046                            self.write_expr(module, arg, func_ctx)?;
4047                            write!(self.out, ")")?;
4048                        }
4049                    }
4050                    Function::CountTrailingZeros => {
4051                        match *func_ctx.resolve_type(arg, &module.types) {
4052                            TypeInner::Vector { size, scalar } => {
4053                                let s = match size {
4054                                    crate::VectorSize::Bi => ".xx",
4055                                    crate::VectorSize::Tri => ".xxx",
4056                                    crate::VectorSize::Quad => ".xxxx",
4057                                };
4058
4059                                let scalar_width_bits = scalar.width * 8;
4060
4061                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4062                                    write!(
4063                                        self.out,
4064                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4065                                    )?;
4066                                    self.write_expr(module, arg, func_ctx)?;
4067                                    write!(self.out, "))")?;
4068                                } else {
4069                                    // This is only needed for the FXC path, on 32bit signed integers.
4070                                    write!(
4071                                        self.out,
4072                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4073                                    )?;
4074                                    self.write_expr(module, arg, func_ctx)?;
4075                                    write!(self.out, ")))")?;
4076                                }
4077                            }
4078                            TypeInner::Scalar(scalar) => {
4079                                let scalar_width_bits = scalar.width * 8;
4080
4081                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4082                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4083                                    self.write_expr(module, arg, func_ctx)?;
4084                                    write!(self.out, "))")?;
4085                                } else {
4086                                    // This is only needed for the FXC path, on 32bit signed integers.
4087                                    write!(
4088                                        self.out,
4089                                        "asint(min({scalar_width_bits}u, firstbitlow("
4090                                    )?;
4091                                    self.write_expr(module, arg, func_ctx)?;
4092                                    write!(self.out, ")))")?;
4093                                }
4094                            }
4095                            _ => unreachable!(),
4096                        }
4097
4098                        return Ok(());
4099                    }
4100                    Function::CountLeadingZeros => {
4101                        match *func_ctx.resolve_type(arg, &module.types) {
4102                            TypeInner::Vector { size, scalar } => {
4103                                let s = match size {
4104                                    crate::VectorSize::Bi => ".xx",
4105                                    crate::VectorSize::Tri => ".xxx",
4106                                    crate::VectorSize::Quad => ".xxxx",
4107                                };
4108
4109                                // scalar width - 1
4110                                let constant = scalar.width * 8 - 1;
4111
4112                                if scalar.kind == ScalarKind::Uint {
4113                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4114                                    self.write_expr(module, arg, func_ctx)?;
4115                                    write!(self.out, "))")?;
4116                                } else {
4117                                    let conversion_func = match scalar.width {
4118                                        4 => "asint",
4119                                        _ => "",
4120                                    };
4121                                    write!(self.out, "(")?;
4122                                    self.write_expr(module, arg, func_ctx)?;
4123                                    write!(
4124                                        self.out,
4125                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4126                                    )?;
4127                                    self.write_expr(module, arg, func_ctx)?;
4128                                    write!(self.out, ")))")?;
4129                                }
4130                            }
4131                            TypeInner::Scalar(scalar) => {
4132                                // scalar width - 1
4133                                let constant = scalar.width * 8 - 1;
4134
4135                                if let ScalarKind::Uint = scalar.kind {
4136                                    write!(self.out, "({constant}u - firstbithigh(")?;
4137                                    self.write_expr(module, arg, func_ctx)?;
4138                                    write!(self.out, "))")?;
4139                                } else {
4140                                    let conversion_func = match scalar.width {
4141                                        4 => "asint",
4142                                        _ => "",
4143                                    };
4144                                    write!(self.out, "(")?;
4145                                    self.write_expr(module, arg, func_ctx)?;
4146                                    write!(
4147                                        self.out,
4148                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4149                                    )?;
4150                                    self.write_expr(module, arg, func_ctx)?;
4151                                    write!(self.out, ")))")?;
4152                                }
4153                            }
4154                            _ => unreachable!(),
4155                        }
4156
4157                        return Ok(());
4158                    }
4159                }
4160            }
4161            Expression::Swizzle {
4162                size,
4163                vector,
4164                pattern,
4165            } => {
4166                self.write_expr(module, vector, func_ctx)?;
4167                write!(self.out, ".")?;
4168                for &sc in pattern[..size as usize].iter() {
4169                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4170                }
4171            }
4172            Expression::ArrayLength(expr) => {
4173                let var_handle = match func_ctx.expressions[expr] {
4174                    Expression::AccessIndex { base, index: _ } => {
4175                        match func_ctx.expressions[base] {
4176                            Expression::GlobalVariable(handle) => handle,
4177                            _ => unreachable!(),
4178                        }
4179                    }
4180                    Expression::GlobalVariable(handle) => handle,
4181                    _ => unreachable!(),
4182                };
4183
4184                let var = &module.global_variables[var_handle];
4185                let (offset, stride) = match module.types[var.ty].inner {
4186                    TypeInner::Array { stride, .. } => (0, stride),
4187                    TypeInner::Struct { ref members, .. } => {
4188                        let last = members.last().unwrap();
4189                        let stride = match module.types[last.ty].inner {
4190                            TypeInner::Array { stride, .. } => stride,
4191                            _ => unreachable!(),
4192                        };
4193                        (last.offset, stride)
4194                    }
4195                    _ => unreachable!(),
4196                };
4197
4198                let storage_access = match var.space {
4199                    crate::AddressSpace::Storage { access } => access,
4200                    _ => crate::StorageAccess::default(),
4201                };
4202                let wrapped_array_length = WrappedArrayLength {
4203                    writable: storage_access.contains(crate::StorageAccess::STORE),
4204                };
4205
4206                write!(self.out, "((")?;
4207                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4208                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4209                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4210            }
4211            Expression::Derivative { axis, ctrl, expr } => {
4212                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4213                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4214                    let tail = match ctrl {
4215                        Ctrl::Coarse => "coarse",
4216                        Ctrl::Fine => "fine",
4217                        Ctrl::None => unreachable!(),
4218                    };
4219                    write!(self.out, "abs(ddx_{tail}(")?;
4220                    self.write_expr(module, expr, func_ctx)?;
4221                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4222                    self.write_expr(module, expr, func_ctx)?;
4223                    write!(self.out, "))")?
4224                } else {
4225                    let fun_str = match (axis, ctrl) {
4226                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4227                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4228                        (Axis::X, Ctrl::None) => "ddx",
4229                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4230                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4231                        (Axis::Y, Ctrl::None) => "ddy",
4232                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4233                        (Axis::Width, Ctrl::None) => "fwidth",
4234                    };
4235                    write!(self.out, "{fun_str}(")?;
4236                    self.write_expr(module, expr, func_ctx)?;
4237                    write!(self.out, ")")?
4238                }
4239            }
4240            Expression::Relational { fun, argument } => {
4241                use crate::RelationalFunction as Rf;
4242
4243                let fun_str = match fun {
4244                    Rf::All => "all",
4245                    Rf::Any => "any",
4246                    Rf::IsNan => "isnan",
4247                    Rf::IsInf => "isinf",
4248                };
4249                write!(self.out, "{fun_str}(")?;
4250                self.write_expr(module, argument, func_ctx)?;
4251                write!(self.out, ")")?
4252            }
4253            Expression::Select {
4254                condition,
4255                accept,
4256                reject,
4257            } => {
4258                write!(self.out, "(")?;
4259                self.write_expr(module, condition, func_ctx)?;
4260                write!(self.out, " ? ")?;
4261                self.write_expr(module, accept, func_ctx)?;
4262                write!(self.out, " : ")?;
4263                self.write_expr(module, reject, func_ctx)?;
4264                write!(self.out, ")")?
4265            }
4266            Expression::RayQueryGetIntersection { query, committed } => {
4267                if committed {
4268                    write!(self.out, "GetCommittedIntersection(")?;
4269                    self.write_expr(module, query, func_ctx)?;
4270                    write!(self.out, ")")?;
4271                } else {
4272                    write!(self.out, "GetCandidateIntersection(")?;
4273                    self.write_expr(module, query, func_ctx)?;
4274                    write!(self.out, ")")?;
4275                }
4276            }
4277            // Not supported yet
4278            Expression::RayQueryVertexPositions { .. } => unreachable!(),
4279            // Nothing to do here, since call expression already cached
4280            Expression::CallResult(_)
4281            | Expression::AtomicResult { .. }
4282            | Expression::WorkGroupUniformLoadResult { .. }
4283            | Expression::RayQueryProceedResult
4284            | Expression::SubgroupBallotResult
4285            | Expression::SubgroupOperationResult { .. } => {}
4286        }
4287
4288        if !closing_bracket.is_empty() {
4289            write!(self.out, "{closing_bracket}")?;
4290        }
4291        Ok(())
4292    }
4293
4294    #[allow(clippy::too_many_arguments)]
4295    fn write_image_load(
4296        &mut self,
4297        module: &&Module,
4298        expr: Handle<crate::Expression>,
4299        func_ctx: &back::FunctionCtx,
4300        image: Handle<crate::Expression>,
4301        coordinate: Handle<crate::Expression>,
4302        array_index: Option<Handle<crate::Expression>>,
4303        sample: Option<Handle<crate::Expression>>,
4304        level: Option<Handle<crate::Expression>>,
4305    ) -> Result<(), Error> {
4306        let mut wrapping_type = None;
4307        match *func_ctx.resolve_type(image, &module.types) {
4308            TypeInner::Image {
4309                class: crate::ImageClass::External,
4310                ..
4311            } => {
4312                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4313                self.write_expr(module, image, func_ctx)?;
4314                write!(self.out, ", ")?;
4315                self.write_expr(module, coordinate, func_ctx)?;
4316                write!(self.out, ")")?;
4317                return Ok(());
4318            }
4319            TypeInner::Image {
4320                class: crate::ImageClass::Storage { format, .. },
4321                ..
4322            } => {
4323                if format.single_component() {
4324                    wrapping_type = Some(Scalar::from(format));
4325                }
4326            }
4327            _ => {}
4328        }
4329        if let Some(scalar) = wrapping_type {
4330            write!(
4331                self.out,
4332                "{}{}(",
4333                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4334                scalar.to_hlsl_str()?
4335            )?;
4336        }
4337        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4338        self.write_expr(module, image, func_ctx)?;
4339        write!(self.out, ".Load(")?;
4340
4341        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4342
4343        if let Some(sample) = sample {
4344            write!(self.out, ", ")?;
4345            self.write_expr(module, sample, func_ctx)?;
4346        }
4347
4348        // close bracket for Load function
4349        write!(self.out, ")")?;
4350
4351        if wrapping_type.is_some() {
4352            write!(self.out, ")")?;
4353        }
4354
4355        // return x component if return type is scalar
4356        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4357            write!(self.out, ".x")?;
4358        }
4359        Ok(())
4360    }
4361
4362    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4363    /// can be generated later.
4364    fn sampler_binding_array_info_from_expression(
4365        &mut self,
4366        module: &Module,
4367        func_ctx: &back::FunctionCtx<'_>,
4368        base: Handle<crate::Expression>,
4369        resolved: &TypeInner,
4370    ) -> Option<BindingArraySamplerInfo> {
4371        if let TypeInner::BindingArray {
4372            base: base_ty_handle,
4373            ..
4374        } = *resolved
4375        {
4376            let base_ty = &module.types[base_ty_handle].inner;
4377            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4378                let base = &func_ctx.expressions[base];
4379
4380                if let crate::Expression::GlobalVariable(handle) = *base {
4381                    let variable = &module.global_variables[handle];
4382
4383                    let sampler_heap_name = match comparison {
4384                        true => COMPARISON_SAMPLER_HEAP_VAR,
4385                        false => SAMPLER_HEAP_VAR,
4386                    };
4387
4388                    return Some(BindingArraySamplerInfo {
4389                        sampler_heap_name,
4390                        sampler_index_buffer_name: self
4391                            .wrapped
4392                            .sampler_index_buffers
4393                            .get(&super::SamplerIndexBufferKey {
4394                                group: variable.binding.unwrap().group,
4395                            })
4396                            .unwrap()
4397                            .clone(),
4398                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4399                            .clone(),
4400                    });
4401                }
4402            }
4403        }
4404
4405        None
4406    }
4407
4408    fn write_named_expr(
4409        &mut self,
4410        module: &Module,
4411        handle: Handle<crate::Expression>,
4412        name: String,
4413        // The expression which is being named.
4414        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4415        named: Handle<crate::Expression>,
4416        ctx: &back::FunctionCtx,
4417    ) -> BackendResult {
4418        match ctx.info[named].ty {
4419            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4420                TypeInner::Struct { .. } => {
4421                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4422                    write!(self.out, "{ty_name}")?;
4423                }
4424                _ => {
4425                    self.write_type(module, ty_handle)?;
4426                }
4427            },
4428            proc::TypeResolution::Value(ref inner) => {
4429                self.write_value_type(module, inner)?;
4430            }
4431        }
4432
4433        let resolved = ctx.resolve_type(named, &module.types);
4434
4435        write!(self.out, " {name}")?;
4436        // If rhs is a array type, we should write array size
4437        if let TypeInner::Array { base, size, .. } = *resolved {
4438            self.write_array_size(module, base, size)?;
4439        }
4440        write!(self.out, " = ")?;
4441        self.write_expr(module, handle, ctx)?;
4442        writeln!(self.out, ";")?;
4443        self.named_expressions.insert(named, name);
4444
4445        Ok(())
4446    }
4447
4448    /// Helper function that write default zero initialization
4449    pub(super) fn write_default_init(
4450        &mut self,
4451        module: &Module,
4452        ty: Handle<crate::Type>,
4453    ) -> BackendResult {
4454        write!(self.out, "(")?;
4455        self.write_type(module, ty)?;
4456        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4457            self.write_array_size(module, base, size)?;
4458        }
4459        write!(self.out, ")0")?;
4460        Ok(())
4461    }
4462
4463    fn write_control_barrier(
4464        &mut self,
4465        barrier: crate::Barrier,
4466        level: back::Level,
4467    ) -> BackendResult {
4468        if barrier.contains(crate::Barrier::STORAGE) {
4469            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4470        }
4471        if barrier.contains(crate::Barrier::WORK_GROUP) {
4472            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4473        }
4474        if barrier.contains(crate::Barrier::SUB_GROUP) {
4475            // Does not exist in DirectX
4476        }
4477        if barrier.contains(crate::Barrier::TEXTURE) {
4478            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4479        }
4480        Ok(())
4481    }
4482
4483    fn write_memory_barrier(
4484        &mut self,
4485        barrier: crate::Barrier,
4486        level: back::Level,
4487    ) -> BackendResult {
4488        if barrier.contains(crate::Barrier::STORAGE) {
4489            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4490        }
4491        if barrier.contains(crate::Barrier::WORK_GROUP) {
4492            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4493        }
4494        if barrier.contains(crate::Barrier::SUB_GROUP) {
4495            // Does not exist in DirectX
4496        }
4497        if barrier.contains(crate::Barrier::TEXTURE) {
4498            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4499        }
4500        Ok(())
4501    }
4502
4503    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4504    fn emit_hlsl_atomic_tail(
4505        &mut self,
4506        module: &Module,
4507        func_ctx: &back::FunctionCtx<'_>,
4508        fun: &crate::AtomicFunction,
4509        compare_expr: Option<Handle<crate::Expression>>,
4510        value: Handle<crate::Expression>,
4511        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4512    ) -> BackendResult {
4513        if let Some(cmp) = compare_expr {
4514            write!(self.out, ", ")?;
4515            self.write_expr(module, cmp, func_ctx)?;
4516        }
4517        write!(self.out, ", ")?;
4518        if let crate::AtomicFunction::Subtract = *fun {
4519            // we just wrote `InterlockedAdd`, so negate the argument
4520            write!(self.out, "-")?;
4521        }
4522        self.write_expr(module, value, func_ctx)?;
4523        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4524            write!(self.out, ", ")?;
4525            if compare_expr.is_some() {
4526                write!(self.out, "{res_name}.old_value")?;
4527            } else {
4528                write!(self.out, "{res_name}")?;
4529            }
4530        }
4531        writeln!(self.out, ");")?;
4532        Ok(())
4533    }
4534}
4535
4536pub(super) struct MatrixType {
4537    pub(super) columns: crate::VectorSize,
4538    pub(super) rows: crate::VectorSize,
4539    pub(super) width: crate::Bytes,
4540}
4541
4542pub(super) fn get_inner_matrix_data(
4543    module: &Module,
4544    handle: Handle<crate::Type>,
4545) -> Option<MatrixType> {
4546    match module.types[handle].inner {
4547        TypeInner::Matrix {
4548            columns,
4549            rows,
4550            scalar,
4551        } => Some(MatrixType {
4552            columns,
4553            rows,
4554            width: scalar.width,
4555        }),
4556        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4557        _ => None,
4558    }
4559}
4560
4561/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4562/// returns a tuple of the matrix, the column (vector) index (if present), and
4563/// the row (scalar) index (if present).
4564fn find_matrix_in_access_chain(
4565    module: &Module,
4566    base: Handle<crate::Expression>,
4567    func_ctx: &back::FunctionCtx<'_>,
4568) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4569    let mut current_base = base;
4570    let mut vector = None;
4571    let mut scalar = None;
4572    loop {
4573        let resolved_tr = func_ctx
4574            .resolve_type(current_base, &module.types)
4575            .pointer_base_type();
4576        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4577
4578        match *resolved {
4579            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4580            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4581            _ => return None,
4582        }
4583
4584        let index;
4585        (current_base, index) = match func_ctx.expressions[current_base] {
4586            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4587            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4588            _ => return None,
4589        };
4590
4591        match *resolved {
4592            TypeInner::Scalar(_) => scalar = Some(index),
4593            TypeInner::Vector { .. } => vector = Some(index),
4594            _ => unreachable!(),
4595        }
4596    }
4597}
4598
4599/// Returns the matrix data if the access chain starting at `base`:
4600/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4601/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4602/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4603pub(super) fn get_inner_matrix_of_struct_array_member(
4604    module: &Module,
4605    base: Handle<crate::Expression>,
4606    func_ctx: &back::FunctionCtx<'_>,
4607    direct: bool,
4608) -> Option<MatrixType> {
4609    let mut mat_data = None;
4610    let mut array_base = None;
4611
4612    let mut current_base = base;
4613    loop {
4614        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4615        if let TypeInner::Pointer { base, .. } = *resolved {
4616            resolved = &module.types[base].inner;
4617        };
4618
4619        match *resolved {
4620            TypeInner::Matrix {
4621                columns,
4622                rows,
4623                scalar,
4624            } => {
4625                mat_data = Some(MatrixType {
4626                    columns,
4627                    rows,
4628                    width: scalar.width,
4629                })
4630            }
4631            TypeInner::Array { base, .. } => {
4632                array_base = Some(base);
4633            }
4634            TypeInner::Struct { .. } => {
4635                if let Some(array_base) = array_base {
4636                    if direct {
4637                        return mat_data;
4638                    } else {
4639                        return get_inner_matrix_data(module, array_base);
4640                    }
4641                }
4642
4643                break;
4644            }
4645            _ => break,
4646        }
4647
4648        current_base = match func_ctx.expressions[current_base] {
4649            crate::Expression::Access { base, .. } => base,
4650            crate::Expression::AccessIndex { base, .. } => base,
4651            _ => break,
4652        };
4653    }
4654    None
4655}
4656
4657/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
4658/// immediate expression, rather than traversing an access chain.
4659fn get_global_uniform_matrix(
4660    module: &Module,
4661    base: Handle<crate::Expression>,
4662    func_ctx: &back::FunctionCtx<'_>,
4663) -> Option<MatrixType> {
4664    let base_tr = func_ctx
4665        .resolve_type(base, &module.types)
4666        .pointer_base_type();
4667    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
4668    match (&func_ctx.expressions[base], base_ty) {
4669        (
4670            &crate::Expression::GlobalVariable(handle),
4671            Some(&TypeInner::Matrix {
4672                columns,
4673                rows,
4674                scalar,
4675            }),
4676        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
4677            Some(MatrixType {
4678                columns,
4679                rows,
4680                width: scalar.width,
4681            })
4682        }
4683        _ => None,
4684    }
4685}
4686
4687/// Returns the matrix data if the access chain starting at `base`:
4688/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
4689/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4690/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
4691fn get_inner_matrix_of_global_uniform(
4692    module: &Module,
4693    base: Handle<crate::Expression>,
4694    func_ctx: &back::FunctionCtx<'_>,
4695) -> Option<MatrixType> {
4696    let mut mat_data = None;
4697    let mut array_base = None;
4698
4699    let mut current_base = base;
4700    loop {
4701        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4702        if let TypeInner::Pointer { base, .. } = *resolved {
4703            resolved = &module.types[base].inner;
4704        };
4705
4706        match *resolved {
4707            TypeInner::Matrix {
4708                columns,
4709                rows,
4710                scalar,
4711            } => {
4712                mat_data = Some(MatrixType {
4713                    columns,
4714                    rows,
4715                    width: scalar.width,
4716                })
4717            }
4718            TypeInner::Array { base, .. } => {
4719                array_base = Some(base);
4720            }
4721            _ => break,
4722        }
4723
4724        current_base = match func_ctx.expressions[current_base] {
4725            crate::Expression::Access { base, .. } => base,
4726            crate::Expression::AccessIndex { base, .. } => base,
4727            crate::Expression::GlobalVariable(handle)
4728                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
4729            {
4730                return mat_data.or_else(|| {
4731                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
4732                })
4733            }
4734            _ => break,
4735        };
4736    }
4737    None
4738}