naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::{self, Write as _},
8    mem,
9};
10
11use super::{
12    help,
13    help::{
14        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15        WrappedZeroValue,
16    },
17    storage::StoreValue,
18    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21    back::{self, get_entry_points, Baked},
22    common,
23    proc::{self, index, ExternalTextureNameKey, NameKey},
24    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Semantic prefix for user-defined varyings: location `N` is written as `LOC{N}`
/// (except for fragment outputs, which use `SV_Target{N}`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Type name of the struct backing the special-constants `ConstantBuffer`.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the special-constants `ConstantBuffer`.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// Name of the `int` member of the special-constants struct holding the first vertex.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// Name of the `int` member of the special-constants struct holding the first instance.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// Name of the trailing `uint` member of the special-constants struct.
/// NOTE(review): its meaning is defined by whoever reads it; not visible here.
const SPECIAL_OTHER: &str = "other";

// Names of HLSL helper functions and variables generated by this backend. They are
// presumably declared and referenced from other parts of the HLSL writer (not shown here).
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Prefix of the per-ray-query tracker variable names (presumably followed by the query's name).
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix for backend-generated variables emitted for a naga statement; it is the same
/// `naga_` prefix used by the helper names above.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
/// An index that is either a runtime expression or a compile-time constant.
/// NOTE(review): presumably used when writing array/vector accesses elsewhere in the writer.
enum Index {
    /// Index computed by a naga expression.
    Expression(Handle<crate::Expression>),
    /// Index known at code-generation time.
    Static(u32),
}
60
/// One member of a flattened entry point interface struct.
pub(super) struct EpStructMember {
    /// Member name in the generated HLSL struct.
    pub(super) name: String,
    /// Type of the member.
    pub(super) ty: Handle<crate::Type>,
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    pub(super) binding: Option<crate::Binding>,
    /// Position of the member in the original flattened order. Members are sorted by
    /// binding while the struct is written; input structs are sorted back by this field.
    pub(super) index: u32,
}
69
/// Describes a generated struct that wraps the flattened interface (inputs or outputs)
/// of an entry point.
pub(super) struct EntryPointBinding {
    /// Name of the variable holding the struct. For inputs this is the fake EP argument
    /// that contains the struct with all the flattened input data.
    pub(super) arg_name: String,
    /// Generated structure name
    pub(super) ty_name: String,
    /// Members of generated structure
    pub(super) members: Vec<EpStructMember>,
    /// Name of the struct member bound to `SV_GroupIndex`. It is either the user's
    /// `local_invocation_index` member or one synthesized because `subgroup_id` is used.
    /// `None` if neither applies.
    pub(super) local_invocation_index_name: Option<String>,
}
82
/// The generated interface structs of a single entry point.
pub(super) struct EntryPointInterface {
    /// If `Some`, the input of an entry point is gathered in a special
    /// struct with members sorted by binding.
    /// The `EntryPointBinding::members` array is sorted by index,
    /// so that we can walk it in `write_ep_arguments_initialization`.
    pub(crate) input: Option<EntryPointBinding>,
    /// If `Some`, the output of an entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in `Statement::Return` handler.
    pub(crate) output: Option<EntryPointBinding>,
    /// Mesh shader vertex output struct, if any (members kept sorted by binding).
    pub(crate) mesh_vertices: Option<EntryPointBinding>,
    /// Mesh shader primitive output struct, if any (members kept sorted by binding).
    pub(crate) mesh_primitives: Option<EntryPointBinding>,
    /// Mesh shader index output information, if any.
    /// NOTE(review): populated outside this view; semantics not visible here.
    pub(crate) mesh_indices: Option<EntryPointBinding>,
}
97
/// Sort key for interface struct members.
///
/// The derived `Ord` follows variant declaration order. Locations sort first (ascending),
/// then built-ins, then unbound members, so do not reorder these variants.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    Location(u32),
    BuiltIn(crate::BuiltIn),
    Other,
}
104
105impl InterfaceKey {
106    const fn new(binding: Option<&crate::Binding>) -> Self {
107        match binding {
108            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110            None => Self::Other,
111        }
112    }
113}
114
/// Which part of an entry point's interface a generated struct describes.
///
/// `Io::Input` structs are re-sorted by original member index after being written; the
/// other kinds stay sorted by binding. Fragment `Io::Output` locations are written as
/// `SV_Target{N}` semantics.
#[derive(Copy, Clone, PartialEq)]
pub(super) enum Io {
    Input,
    Output,
    MeshVertices,
    MeshPrimitives,
}
122
/// Argument list for nested entry points
///
/// NOTE(review): presumably used when a wrapper entry point (e.g. for mesh/task shaders)
/// forwards its arguments to the user-written function; confirm against the call sites.
pub(super) struct NestedEntryPointArgs {
    /// Arguments literally declared by the user
    pub user_args: Vec<String>,
    /// Expression passed for the task payload, if the entry point has one.
    pub task_payload: Option<String>,
    /// Expression passed for the local invocation index.
    pub local_invocation_index: String,
}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133        return false;
134    };
135    matches!(
136        builtin,
137        crate::BuiltIn::SubgroupSize
138            | crate::BuiltIn::SubgroupInvocationId
139            | crate::BuiltIn::NumSubgroups
140            | crate::BuiltIn::SubgroupId
141    )
142}
143
/// Information for how to generate a `binding_array<sampler>` access.
///
/// The access goes through a sampler heap, indexed via a sampler index buffer that
/// starts at a per-binding-array base index.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156        Self {
157            out,
158            names: crate::FastHashMap::default(),
159            namer: proc::Namer::default(),
160            options,
161            pipeline_options,
162            entry_point_io: crate::FastHashMap::default(),
163            named_expressions: crate::NamedExpressions::default(),
164            wrapped: super::Wrapped::default(),
165            written_committed_intersection: false,
166            written_candidate_intersection: false,
167            continue_ctx: back::continue_forward::ContinueCtx::default(),
168            temp_access_chain: Vec::new(),
169            need_bake_expressions: Default::default(),
170            function_task_payload_var: Default::default(),
171        }
172    }
173
174    fn reset(&mut self, module: &Module) {
175        self.names.clear();
176        self.namer.reset(
177            module,
178            &super::keywords::RESERVED_SET,
179            proc::KeywordSet::empty(),
180            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181            super::keywords::RESERVED_PREFIXES,
182            &mut self.names,
183        );
184        self.entry_point_io.clear();
185        self.named_expressions.clear();
186        self.wrapped.clear();
187        self.written_committed_intersection = false;
188        self.written_candidate_intersection = false;
189        self.continue_ctx.clear();
190        self.need_bake_expressions.clear();
191        self.function_task_payload_var.clear();
192    }
193
194    /// Generates statements to be inserted immediately before and at the very
195    /// start of the body of each loop, to defeat infinite loop reasoning.
196    /// The 0th item of the returned tuple should be inserted immediately prior
197    /// to the loop and the 1st item should be inserted at the very start of
198    /// the loop body.
199    ///
200    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
201    fn gen_force_bounded_loop_statements(
202        &mut self,
203        level: back::Level,
204    ) -> Option<(String, String)> {
205        if !self.options.force_loop_bounding {
206            return None;
207        }
208
209        let loop_bound_name = self.namer.call("loop_bound");
210        let max = u32::MAX;
211        // Count down from u32::MAX rather than up from 0 to avoid hang on
212        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
213        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214        let level = level.next();
215        let break_and_inc = format!(
216            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218        );
219
220        Some((decl, break_and_inc))
221    }
222
223    /// Helper method used to find which expressions of a given function require baking
224    ///
225    /// # Notes
226    /// Clears `need_bake_expressions` set before adding to it
227    fn update_expressions_to_bake(
228        &mut self,
229        module: &Module,
230        func: &crate::Function,
231        info: &valid::FunctionInfo,
232    ) {
233        use crate::Expression;
234        self.need_bake_expressions.clear();
235        for (exp_handle, expr) in func.expressions.iter() {
236            let expr_info = &info[exp_handle];
237            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238            if min_ref_count <= expr_info.ref_count {
239                self.need_bake_expressions.insert(exp_handle);
240            }
241            if let Expression::Load { pointer } = *expr {
242                if info[pointer]
243                    .ty
244                    .inner_with(&module.types)
245                    .is_atomic_pointer(&module.types)
246                {
247                    self.need_bake_expressions.insert(exp_handle);
248                }
249            }
250
251            if let Expression::Math { fun, arg, arg1, .. } = *expr {
252                match fun {
253                    crate::MathFunction::Asinh
254                    | crate::MathFunction::Acosh
255                    | crate::MathFunction::Atanh
256                    | crate::MathFunction::Unpack2x16float
257                    | crate::MathFunction::Unpack2x16snorm
258                    | crate::MathFunction::Unpack2x16unorm
259                    | crate::MathFunction::Unpack4x8snorm
260                    | crate::MathFunction::Unpack4x8unorm
261                    | crate::MathFunction::Unpack4xI8
262                    | crate::MathFunction::Unpack4xU8
263                    | crate::MathFunction::Pack2x16float
264                    | crate::MathFunction::Pack2x16snorm
265                    | crate::MathFunction::Pack2x16unorm
266                    | crate::MathFunction::Pack4x8snorm
267                    | crate::MathFunction::Pack4x8unorm
268                    | crate::MathFunction::Pack4xI8
269                    | crate::MathFunction::Pack4xU8
270                    | crate::MathFunction::Pack4xI8Clamp
271                    | crate::MathFunction::Pack4xU8Clamp => {
272                        self.need_bake_expressions.insert(arg);
273                    }
274                    crate::MathFunction::CountLeadingZeros => {
275                        let inner = info[exp_handle].ty.inner_with(&module.types);
276                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277                            self.need_bake_expressions.insert(arg);
278                        }
279                    }
280                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281                        self.need_bake_expressions.insert(arg);
282                        self.need_bake_expressions.insert(arg1.unwrap());
283                    }
284                    _ => {}
285                }
286            }
287
288            if let Expression::Derivative { axis, ctrl, expr } = *expr {
289                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291                    self.need_bake_expressions.insert(expr);
292                }
293            }
294
295            if let Expression::GlobalVariable(_) = *expr {
296                let inner = info[exp_handle].ty.inner_with(&module.types);
297
298                if let TypeInner::Sampler { .. } = *inner {
299                    self.need_bake_expressions.insert(exp_handle);
300                }
301            }
302        }
303        for statement in func.body.iter() {
304            match *statement {
305                crate::Statement::SubgroupCollectiveOperation {
306                    op: _,
307                    collective_op: crate::CollectiveOperation::InclusiveScan,
308                    argument,
309                    result: _,
310                } => {
311                    self.need_bake_expressions.insert(argument);
312                }
313                crate::Statement::Atomic {
314                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315                    ..
316                } => {
317                    self.need_bake_expressions.insert(cmp);
318                }
319                _ => {}
320            }
321        }
322    }
323
    /// Writes `module` as HLSL source to `self.out`, after resetting all per-module state.
    ///
    /// Items are emitted in this order:
    /// 1. The special-constants buffer, when `special_constants_binding` is set.
    /// 2. The dynamic storage buffer offset buffers.
    /// 3. The output of `write_all_mat_cx2_typedefs_and_functions`.
    /// 4. All sized structs.
    /// 5. Special and wrapped helper functions for global expressions.
    /// 6. Named constants and global variables.
    /// 7. Entry point interface structs.
    /// 8. Regular functions.
    /// 9. The entry points themselves.
    ///
    /// Unless `fake_missing_bindings` is set, a regular function or entry point is not written
    /// if it uses a global whose binding cannot be resolved. For entry points, the binding
    /// error is recorded in the returned reflection info in place of the entry point's name.
    ///
    /// # Errors
    ///
    /// - [`Error::ShaderModelTooLow`] if the module uses mesh shaders and the target shader
    ///   model is below 6.5.
    /// - [`Error::EntryPointNotFound`] if the entry point selected by the pipeline options
    ///   does not exist.
    /// - Any error raised while writing individual items.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Mesh shaders need SM 6.5; bail out before any output is produced.
        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
            return Err(Error::ShaderModelTooLow(
                "mesh shaders".to_string(),
                ShaderModel::V6_5,
            ));
        }

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // The register space is only spelled out when it is non-zero.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One constant buffer per bind group, holding `bt.size` `uint` offsets.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                // If this struct is the result type of an entry point, tell `write_struct`
                // which stage produces it (as an output).
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        // Range of entry points to emit, as selected by the pipeline options.
        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;

            writeln!(self.out)?;
        }

        // One entry per emitted-or-skipped entry point: its name, or the binding error.
        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            // Mesh/task shaders have a wrapper entry point which is declared after the "main"
            // user-written function. We therefore cannot always just document the next function.
            let mut attribute_string = String::new();
            if ep.stage.compute_like() {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    attribute_string,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }
            if let Some(ref info) = ep.mesh_info {
                // NOTE(review): point topology is assumed to be rejected before this point.
                let topology_str = match info.topology {
                    crate::MeshOutputTopology::Points => unreachable!(),
                    crate::MeshOutputTopology::Lines => "line",
                    crate::MeshOutputTopology::Triangles => "triangle",
                };
                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;

            // Separate entry points with a blank line, but not after the module's last one.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
578
579    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580        match *binding {
581            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582                write!(self.out, "precise ")?;
583            }
584            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585                write!(self.out, "noperspective ")?;
586            }
587            crate::Binding::Location {
588                interpolation,
589                sampling,
590                ..
591            } => {
592                if let Some(interpolation) = interpolation {
593                    if let Some(string) = interpolation.to_hlsl_str() {
594                        write!(self.out, "{string} ")?
595                    }
596                }
597
598                if let Some(sampling) = sampling {
599                    if let Some(string) = sampling.to_hlsl_str() {
600                        write!(self.out, "{string} ")?
601                    }
602                }
603            }
604            crate::Binding::BuiltIn(_) => {}
605        }
606
607        Ok(())
608    }
609
610    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
611    // if they are struct, so that the `stage` argument here could be omitted.
612    pub(super) fn write_semantic(
613        &mut self,
614        binding: &Option<crate::Binding>,
615        stage: Option<(ShaderStage, Io)>,
616    ) -> BackendResult {
617        let is_per_primitive = match *binding {
618            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619                if builtin == crate::BuiltIn::ViewIndex
620                    && self.options.shader_model < ShaderModel::V6_1
621                {
622                    return Err(Error::ShaderModelTooLow(
623                        "used @builtin(view_index) or SV_ViewID".to_string(),
624                        ShaderModel::V6_1,
625                    ));
626                }
627                if let Some(builtin_str) = builtin.to_hlsl_str()? {
628                    write!(self.out, " : {builtin_str}")?;
629                }
630                false
631            }
632            Some(crate::Binding::Location {
633                blend_src: Some(1),
634                per_primitive,
635                ..
636            }) => {
637                write!(self.out, " : SV_Target1")?;
638                per_primitive
639            }
640            Some(crate::Binding::Location {
641                location,
642                per_primitive,
643                ..
644            }) => {
645                if stage == Some((ShaderStage::Fragment, Io::Output)) {
646                    write!(self.out, " : SV_Target{location}")?;
647                } else {
648                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649                }
650                per_primitive
651            }
652            _ => false,
653        };
654        if is_per_primitive {
655            write!(self.out, " : primitive")?;
656        }
657
658        Ok(())
659    }
660
661    pub(super) fn write_interface_struct(
662        &mut self,
663        module: &Module,
664        shader_stage: (ShaderStage, Io),
665        struct_name: String,
666        var_name: Option<&str>,
667        mut members: Vec<EpStructMember>,
668    ) -> Result<EntryPointBinding, Error> {
669        let struct_name = self.namer.call(&struct_name);
670        // Sort the members so that first come the user-defined varyings
671        // in ascending locations, and then built-ins. This allows VS and FS
672        // interfaces to match with regards to order.
673        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675        write!(self.out, "struct {struct_name}")?;
676        writeln!(self.out, " {{")?;
677        let mut local_invocation_index_name = None;
678        let mut subgroup_id_used = false;
679        for m in members.iter() {
680            // Sanity check that each IO member is a built-in or is assigned a
681            // location. Also see note about nesting in `write_ep_input_struct`.
682            debug_assert!(m.binding.is_some());
683
684            match m.binding {
685                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686                    subgroup_id_used = true;
687                }
688                Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689                    local_invocation_index_name = Some(m.name.clone());
690                }
691                _ => (),
692            }
693
694            if is_subgroup_builtin_binding(&m.binding) {
695                continue;
696            }
697            write!(self.out, "{}", back::INDENT)?;
698            if let Some(ref binding) = m.binding {
699                self.write_modifier(binding)?;
700            }
701            self.write_type(module, m.ty)?;
702            write!(self.out, " {}", &m.name)?;
703            self.write_semantic(&m.binding, Some(shader_stage))?;
704            writeln!(self.out, ";")?;
705        }
706        if subgroup_id_used && local_invocation_index_name.is_none() {
707            let name = self.namer.call("local_invocation_index");
708            writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709            local_invocation_index_name = Some(name);
710        }
711        writeln!(self.out, "}};")?;
712        writeln!(self.out)?;
713
714        // See ordering notes on EntryPointInterface fields
715        match shader_stage.1 {
716            Io::Input => {
717                // bring back the original order
718                members.sort_by_key(|m| m.index);
719            }
720            Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721                // keep it sorted by binding
722            }
723        }
724
725        Ok(EntryPointBinding {
726            arg_name: self
727                .namer
728                .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729            ty_name: struct_name,
730            members,
731            local_invocation_index_name,
732        })
733    }
734
735    /// Flatten all entry point arguments into a single struct.
736    /// This is needed since we need to re-order them: first placing user locations,
737    /// then built-ins.
738    fn write_ep_input_struct(
739        &mut self,
740        module: &Module,
741        func: &crate::Function,
742        stage: ShaderStage,
743        entry_point_name: &str,
744    ) -> Result<EntryPointBinding, Error> {
745        let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747        let mut fake_members = Vec::new();
748        for arg in func.arguments.iter() {
749            // NOTE: We don't need to handle nesting structs. All members must
750            // be either built-ins or assigned a location. I.E. `binding` is
751            // `Some`. This is checked in `VaryingContext::validate`. See:
752            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
753            match module.types[arg.ty].inner {
754                TypeInner::Struct { ref members, .. } => {
755                    for member in members.iter() {
756                        let name = self.namer.call_or(&member.name, "member");
757                        let index = fake_members.len() as u32;
758                        fake_members.push(EpStructMember {
759                            name,
760                            ty: member.ty,
761                            binding: member.binding.clone(),
762                            index,
763                        });
764                    }
765                }
766                _ => {
767                    let member_name = self.namer.call_or(&arg.name, "member");
768                    let index = fake_members.len() as u32;
769                    fake_members.push(EpStructMember {
770                        name: member_name,
771                        ty: arg.ty,
772                        binding: arg.binding.clone(),
773                        index,
774                    });
775                }
776            }
777        }
778
779        self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780    }
781
782    /// Flatten all entry point results into a single struct.
783    /// This is needed since we need to re-order them: first placing user locations,
784    /// then built-ins.
785    fn write_ep_output_struct(
786        &mut self,
787        module: &Module,
788        result: &crate::FunctionResult,
789        stage: ShaderStage,
790        entry_point_name: &str,
791        frag_ep: Option<&FragmentEntryPoint<'_>>,
792    ) -> Result<EntryPointBinding, Error> {
793        let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795        let empty = [];
796        let members = match module.types[result.ty].inner {
797            TypeInner::Struct { ref members, .. } => members,
798            ref other => {
799                log::error!("Unexpected {other:?} output type without a binding");
800                &empty[..]
801            }
802        };
803
804        // Gather list of fragment input locations. We use this below to remove user-defined
805        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
806        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
807        // writer is supplied with information about the fragment entry point.
808        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809            let mut fs_input_locs = Vec::new();
810            for arg in frag_ep.func.arguments.iter() {
811                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813                    Some(crate::Binding::BuiltIn(_)) | None => {}
814                };
815
816                // NOTE: We don't need to handle struct nesting. See note in
817                // `write_ep_input_struct`.
818                match frag_ep.module.types[arg.ty].inner {
819                    TypeInner::Struct { ref members, .. } => {
820                        for member in members.iter() {
821                            push_if_location(&member.binding);
822                        }
823                    }
824                    _ => push_if_location(&arg.binding),
825                }
826            }
827            fs_input_locs.sort();
828            Some(fs_input_locs)
829        } else {
830            None
831        };
832
833        let mut fake_members = Vec::new();
834        for (index, member) in members.iter().enumerate() {
835            if let Some(ref fs_input_locs) = fs_input_locs {
836                match member.binding {
837                    Some(crate::Binding::Location { location, .. }) => {
838                        if fs_input_locs.binary_search(&location).is_err() {
839                            continue;
840                        }
841                    }
842                    Some(crate::Binding::BuiltIn(_)) | None => {}
843                }
844            }
845
846            let member_name = self.namer.call_or(&member.name, "member");
847            fake_members.push(EpStructMember {
848                name: member_name,
849                ty: member.ty,
850                binding: member.binding.clone(),
851                index: index as u32,
852            });
853        }
854
855        self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856    }
857
858    /// Writes special interface structures for an entry point. The special structures have
859    /// all the fields flattened into them and sorted by binding. They are needed to emulate
860    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
861    fn write_ep_interface(
862        &mut self,
863        module: &Module,
864        ep: &crate::EntryPoint,
865        ep_name: &str,
866        frag_ep: Option<&FragmentEntryPoint<'_>>,
867    ) -> Result<EntryPointInterface, Error> {
868        let func = &ep.function;
869        let stage = ep.stage;
870        Ok(EntryPointInterface {
871            input: if !func.arguments.is_empty()
872                && (stage == ShaderStage::Fragment
873                    || func
874                        .arguments
875                        .iter()
876                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877            {
878                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879            } else {
880                None
881            },
882            output: match func.result {
883                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885                }
886                _ => None,
887            },
888            mesh_vertices: if let Some(ref info) = ep.mesh_info {
889                Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890            } else {
891                None
892            },
893            mesh_primitives: if let Some(ref info) = ep.mesh_info {
894                Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895            } else {
896                None
897            },
898            mesh_indices: if let Some(ref info) = ep.mesh_info {
899                Some(self.write_ep_mesh_output_indices(info.topology)?)
900            } else {
901                None
902            },
903        })
904    }
905
906    fn write_ep_argument_initialization(
907        &mut self,
908        ep: &crate::EntryPoint,
909        ep_input: &EntryPointBinding,
910        fake_member: &EpStructMember,
911    ) -> BackendResult {
912        match fake_member.binding {
913            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914                write!(self.out, "WaveGetLaneCount()")?
915            }
916            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917                write!(self.out, "WaveGetLaneIndex()")?
918            }
919            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920                self.out,
921                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923            )?,
924            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925                write!(
926                    self.out,
927                    "{}.{} / WaveGetLaneCount()",
928                    ep_input.arg_name,
929                    // When writing SubgroupId, we always guarantee that local_invocation_index_name is written
930                    ep_input.local_invocation_index_name.as_ref().unwrap()
931                )?;
932            }
933            Some(crate::Binding::Location {
934                interpolation: Some(crate::Interpolation::PerVertex),
935                ..
936            }) => {
937                if self.options.shader_model < ShaderModel::V6_1 {
938                    return Err(Error::ShaderModelTooLow(
939                        "per_vertex fragment inputs".to_string(),
940                        ShaderModel::V6_1,
941                    ));
942                }
943                write!(
944                    self.out,
945                    "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946                    ep_input.arg_name,
947                    fake_member.name,
948                )?;
949            }
950            _ => {
951                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952            }
953        }
954        Ok(())
955    }
956
957    /// Write an entry point preface that initializes the arguments as specified in IR.
958    fn write_ep_arguments_initialization(
959        &mut self,
960        module: &Module,
961        func: &crate::Function,
962        ep_index: u16,
963    ) -> BackendResult {
964        let ep = &module.entry_points[ep_index as usize];
965        let ep_input = match self
966            .entry_point_io
967            .get_mut(&(ep_index as usize))
968            .unwrap()
969            .input
970            .take()
971        {
972            Some(ep_input) => ep_input,
973            None => return Ok(()),
974        };
975        let mut fake_iter = ep_input.members.iter();
976        for (arg_index, arg) in func.arguments.iter().enumerate() {
977            write!(self.out, "{}", back::INDENT)?;
978            self.write_type(module, arg.ty)?;
979            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980            write!(self.out, " {arg_name}")?;
981            match module.types[arg.ty].inner {
982                TypeInner::Array { base, size, .. } => {
983                    self.write_array_size(module, base, size)?;
984                    write!(self.out, " = ")?;
985                    self.write_ep_argument_initialization(
986                        ep,
987                        &ep_input,
988                        fake_iter.next().unwrap(),
989                    )?;
990                    writeln!(self.out, ";")?;
991                }
992                TypeInner::Struct { ref members, .. } => {
993                    write!(self.out, " = {{ ")?;
994                    for index in 0..members.len() {
995                        if index != 0 {
996                            write!(self.out, ", ")?;
997                        }
998                        self.write_ep_argument_initialization(
999                            ep,
1000                            &ep_input,
1001                            fake_iter.next().unwrap(),
1002                        )?;
1003                    }
1004                    writeln!(self.out, " }};")?;
1005                }
1006                _ => {
1007                    write!(self.out, " = ")?;
1008                    self.write_ep_argument_initialization(
1009                        ep,
1010                        &ep_input,
1011                        fake_iter.next().unwrap(),
1012                    )?;
1013                    writeln!(self.out, ";")?;
1014                }
1015            }
1016        }
1017        assert!(fake_iter.next().is_none());
1018        Ok(())
1019    }
1020
    /// Helper method used to write global variables.
    ///
    /// Dispatches on the global's type and address space:
    /// - External textures, including binding arrays of them, are delegated to
    ///   `write_global_external_texture`, which resolves their bindings itself.
    /// - Globals whose resource binding fails to resolve are skipped with a
    ///   debug log, and nothing is written.
    /// - Samplers, including binding arrays of them, are delegated to
    ///   `write_global_sampler`.
    /// - Everything else is declared here. The storage-class keyword and the
    ///   register letter are chosen from `global.space`.
    ///
    /// # Errors
    /// Returns [`Error::Unimplemented`] for immediate-data globals whose type is
    /// not a struct. Formatting and type-writing errors are propagated.
    ///
    /// # Panics
    /// - On the `Function`, `RayPayload`, and `IncomingRayPayload` address spaces.
    /// - For immediate-data globals when `options.immediates_target` is `None`.
    ///
    /// # Notes
    /// Ends in a newline whenever a declaration is written.
    fn write_global(
        &mut self,
        module: &Module,
        handle: Handle<crate::GlobalVariable>,
    ) -> BackendResult {
        let global = &module.global_variables[handle];
        let inner = &module.types[global.ty].inner;

        // For binding arrays, classify the global by its element type.
        let handle_ty = match *inner {
            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
            _ => inner,
        };

        // External textures are handled entirely differently, so defer entirely to that method.
        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
        // their bindings separately.
        let is_external_texture = matches!(
            *handle_ty,
            TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            }
        );
        if is_external_texture {
            return self.write_global_external_texture(module, handle, global);
        }

        // Unresolvable bindings mean the global is not accessible to this pipeline: skip it.
        if let Some(ref binding) = global.binding {
            if let Err(err) = self.options.resolve_resource_binding(binding) {
                log::debug!(
                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                    handle,
                    global.name,
                    err,
                );
                return Ok(());
            }
        }

        // Samplers are handled entirely differently, so defer entirely to that method.
        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

        if is_sampler {
            return self.write_global_sampler(module, handle, global);
        }

        // Emit the storage-class prefix (and, for most spaces, the type), and pick the
        // register class letter used in `register(...)` below ("" means no register).
        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
        let register_ty = match global.space {
            crate::AddressSpace::Function => unreachable!("Function address space"),
            crate::AddressSpace::Private => {
                write!(self.out, "static ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
                write!(self.out, "groupshared ")?;
                self.write_type(module, global.ty)?;
                ""
            }
            crate::AddressSpace::Uniform => {
                // constant buffer declarations are expected to be inlined, e.g.
                // `cbuffer foo: register(b0) { field1: type1; }`
                write!(self.out, "cbuffer")?;
                "b"
            }
            crate::AddressSpace::Storage { access } => {
                if global
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    write!(self.out, "globallycoherent ")?;
                }
                // Writable storage needs a UAV (`RW...`, register `u`); read-only uses an SRV (`t`).
                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                    ("RW", "u")
                } else {
                    ("", "t")
                };
                write!(self.out, "{prefix}ByteAddressBuffer")?;
                register
            }
            crate::AddressSpace::Handle => {
                let register = match *handle_ty {
                    // all storage textures are UAV, unconditionally
                    TypeInner::Image {
                        class: crate::ImageClass::Storage { .. },
                        ..
                    } => "u",
                    _ => "t",
                };
                self.write_type(module, global.ty)?;
                register
            }
            crate::AddressSpace::Immediate => {
                // The type of the immediates will be wrapped in `ConstantBuffer`
                write!(self.out, "ConstantBuffer<")?;
                "b"
            }
            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
                unimplemented!()
            }
        };

        // If the global is a immediate data write the type now because it will be a
        // generic argument to `ConstantBuffer`
        if global.space == crate::AddressSpace::Immediate {
            self.write_global_type(module, global.ty)?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            // Close the angled brackets for the generic argument
            write!(self.out, ">")?;
        }

        let name = &self.names[&NameKey::GlobalVariable(handle)];
        write!(self.out, " {name}")?;

        // Immediates need to be assigned a binding explicitly by the consumer
        // since naga has no way to know the binding from the shader alone
        if global.space == crate::AddressSpace::Immediate {
            // NOTE(review): this type check runs after the `ConstantBuffer<...> name`
            // prefix has already been written to `self.out`.
            match module.types[global.ty].inner {
                TypeInner::Struct { .. } => {}
                _ => {
                    return Err(Error::Unimplemented(format!(
                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
                    )));
                }
            }

            let target = self
                .options
                .immediates_target
                .as_ref()
                .expect("No bind target was defined for the immediates block");
            write!(self.out, ": register(b{}", target.register)?;
            if target.space != 0 {
                write!(self.out, ", space{}", target.space)?;
            }
            write!(self.out, ")")?;
        }

        if let Some(ref binding) = global.binding {
            // this was already resolved earlier when we started evaluating an entry point.
            let bt = self.options.resolve_resource_binding(binding).unwrap();

            // need to write the binding array size if the type was emitted with `write_type`
            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
                // A size from the bind target overrides the size declared in the IR type.
                if let Some(overridden_size) = bt.binding_array_size {
                    write!(self.out, "[{overridden_size}]")?;
                } else {
                    self.write_array_size(module, base, size)?;
                }
            }

            write!(self.out, " : register({}{}", register_ty, bt.register)?;
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            write!(self.out, ")")?;
        } else {
            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }
            // Unbound private globals are initialized in place: with their IR
            // initializer if present, otherwise with the type's default value.
            if global.space == crate::AddressSpace::Private {
                write!(self.out, " = ")?;
                if let Some(init) = global.init {
                    self.write_const_expression(module, init, &module.global_expressions)?;
                } else {
                    self.write_default_init(module, global.ty)?;
                }
            }
        }

        // A uniform was opened as `cbuffer <name> : register(...)`; now emit its
        // body, a single member of the global's type with the same name.
        if global.space == crate::AddressSpace::Uniform {
            write!(self.out, " {{ ")?;

            self.write_global_type(module, global.ty)?;

            write!(
                self.out,
                " {}",
                &self.names[&NameKey::GlobalVariable(handle)]
            )?;

            // need to write the array size if the type was emitted with `write_type`
            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
                self.write_array_size(module, base, size)?;
            }

            writeln!(self.out, "; }}")?;
        } else {
            writeln!(self.out, ";")?;
        }

        Ok(())
    }
1223
1224    fn write_global_sampler(
1225        &mut self,
1226        module: &Module,
1227        handle: Handle<crate::GlobalVariable>,
1228        global: &crate::GlobalVariable,
1229    ) -> BackendResult {
1230        let binding = *global.binding.as_ref().unwrap();
1231
1232        let key = super::SamplerIndexBufferKey {
1233            group: binding.group,
1234        };
1235        self.write_wrapped_sampler_buffer(key)?;
1236
1237        // This was already validated, so we can confidently unwrap it.
1238        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240        match module.types[global.ty].inner {
1241            TypeInner::Sampler { comparison } => {
1242                // If we are generating a static access, we create a variable for the sampler.
1243                //
1244                // This prevents the DXIL from containing multiple lookups for the sampler, which
1245                // the backend compiler will then have to eliminate. AMD does seem to be able to
1246                // eliminate these, but better safe than sorry.
1247
1248                write!(self.out, "static const ")?;
1249                self.write_type(module, global.ty)?;
1250
1251                let heap_var = if comparison {
1252                    COMPARISON_SAMPLER_HEAP_VAR
1253                } else {
1254                    SAMPLER_HEAP_VAR
1255                };
1256
1257                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258                let name = &self.names[&NameKey::GlobalVariable(handle)];
1259                writeln!(
1260                    self.out,
1261                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262                    register = bt.register
1263                )?;
1264            }
1265            TypeInner::BindingArray { .. } => {
1266                // If we are generating a binding array, we cannot directly access the sampler as the index
1267                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1268                // that represents the "base" index into the sampler index buffer. This constant is added
1269                // to the user provided index to get the final index into the sampler index buffer.
1270
1271                let name = &self.names[&NameKey::GlobalVariable(handle)];
1272                writeln!(
1273                    self.out,
1274                    "static const uint {name} = {register};",
1275                    register = bt.register
1276                )?;
1277            }
1278            _ => unreachable!(),
1279        };
1280
1281        Ok(())
1282    }
1283
1284    /// Write the declarations for an external texture global variable.
1285    /// These are emitted as multiple global variables: Three `Texture2D`s
1286    /// (one for each plane) and a parameters cbuffer.
1287    fn write_global_external_texture(
1288        &mut self,
1289        module: &Module,
1290        handle: Handle<crate::GlobalVariable>,
1291        global: &crate::GlobalVariable,
1292    ) -> BackendResult {
1293        let res_binding = global
1294            .binding
1295            .as_ref()
1296            .expect("External texture global variables must have a resource binding");
1297        let ext_tex_bindings = match self
1298            .options
1299            .resolve_external_texture_resource_binding(res_binding)
1300        {
1301            Ok(bindings) => bindings,
1302            Err(err) => {
1303                log::debug!(
1304                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305                    handle,
1306                    global.name,
1307                    err,
1308                );
1309                return Ok(());
1310            }
1311        };
1312
1313        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314            write!(
1315                self.out,
1316                "Texture2D<float4> {}: register(t{}",
1317                name, bt.register
1318            )?;
1319            if bt.space != 0 {
1320                write!(self.out, ", space{}", bt.space)?;
1321            }
1322            writeln!(self.out, ");")?;
1323            Ok(())
1324        };
1325        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326            let plane_name = &self.names
1327                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328            write_plane(bt, plane_name)?;
1329        }
1330
1331        let params_name = &self.names
1332            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333        let params_ty_name =
1334            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335        write!(
1336            self.out,
1337            "cbuffer {}: register(b{}",
1338            params_name, ext_tex_bindings.params.register
1339        )?;
1340        if ext_tex_bindings.params.space != 0 {
1341            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342        }
1343        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345        Ok(())
1346    }
1347
1348    /// Helper method used to write global constants
1349    ///
1350    /// # Notes
1351    /// Ends in a newline
1352    fn write_global_constant(
1353        &mut self,
1354        module: &Module,
1355        handle: Handle<crate::Constant>,
1356    ) -> BackendResult {
1357        write!(self.out, "static const ")?;
1358        let constant = &module.constants[handle];
1359        self.write_type(module, constant.ty)?;
1360        let name = &self.names[&NameKey::Constant(handle)];
1361        write!(self.out, " {name}")?;
1362        // Write size for array type
1363        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364            self.write_array_size(module, base, size)?;
1365        }
1366        write!(self.out, " = ")?;
1367        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368        writeln!(self.out, ";")?;
1369        Ok(())
1370    }
1371
1372    pub(super) fn write_array_size(
1373        &mut self,
1374        module: &Module,
1375        base: Handle<crate::Type>,
1376        size: crate::ArraySize,
1377    ) -> BackendResult {
1378        write!(self.out, "[")?;
1379
1380        match size.resolve(module.to_ctx())? {
1381            proc::IndexableLength::Known(size) => {
1382                write!(self.out, "{size}")?;
1383            }
1384            proc::IndexableLength::Dynamic => unreachable!(),
1385        }
1386
1387        write!(self.out, "]")?;
1388
1389        if let TypeInner::Array {
1390            base: next_base,
1391            size: next_size,
1392            ..
1393        } = module.types[base].inner
1394        {
1395            self.write_array_size(module, next_base, next_size)?;
1396        }
1397
1398        Ok(())
1399    }
1400
1401    /// Helper method used to write structs
1402    ///
1403    /// # Notes
1404    /// Ends in a newline
1405    fn write_struct(
1406        &mut self,
1407        module: &Module,
1408        handle: Handle<crate::Type>,
1409        members: &[crate::StructMember],
1410        span: u32,
1411        shader_stage: Option<(ShaderStage, Io)>,
1412    ) -> BackendResult {
1413        // Write struct name
1414        let struct_name = &self.names[&NameKey::Type(handle)];
1415        writeln!(self.out, "struct {struct_name} {{")?;
1416
1417        let mut last_offset = 0;
1418        for (index, member) in members.iter().enumerate() {
1419            if member.binding.is_none() && member.offset > last_offset {
1420                // using int as padding should work as long as the backend
1421                // doesn't support a type that's less than 4 bytes in size
1422                // (Error::UnsupportedScalar catches this)
1423                let padding = (member.offset - last_offset) / 4;
1424                for i in 0..padding {
1425                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1426                }
1427            }
1428            let ty_inner = &module.types[member.ty].inner;
1429            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1430
1431            // The indentation is only for readability
1432            write!(self.out, "{}", back::INDENT)?;
1433
1434            match module.types[member.ty].inner {
1435                TypeInner::Array { base, size, .. } => {
1436                    // HLSL arrays are written as `type name[size]`
1437
1438                    self.write_global_type(module, member.ty)?;
1439
1440                    // Write `name`
1441                    write!(
1442                        self.out,
1443                        " {}",
1444                        &self.names[&NameKey::StructMember(handle, index as u32)]
1445                    )?;
1446                    // Write [size]
1447                    self.write_array_size(module, base, size)?;
1448                }
1449                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1450                // See the module-level block comment in mod.rs for details.
1451                TypeInner::Matrix {
1452                    rows,
1453                    columns,
1454                    scalar,
1455                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1456                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1457                    let field_name_key = NameKey::StructMember(handle, index as u32);
1458
1459                    for i in 0..columns as u8 {
1460                        if i != 0 {
1461                            write!(self.out, "; ")?;
1462                        }
1463                        self.write_value_type(module, &vec_ty)?;
1464                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1465                    }
1466                }
1467                _ => {
1468                    // Write modifier before type
1469                    if let Some(ref binding) = member.binding {
1470                        self.write_modifier(binding)?;
1471                    }
1472
1473                    // Even though Naga IR matrices are column-major, we must describe
1474                    // matrices passed from the CPU as being in row-major order.
1475                    // See the module-level block comment in mod.rs for details.
1476                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1477                        write!(self.out, "row_major ")?;
1478                    }
1479
1480                    // Write the member type and name
1481                    self.write_type(module, member.ty)?;
1482                    write!(
1483                        self.out,
1484                        " {}",
1485                        &self.names[&NameKey::StructMember(handle, index as u32)]
1486                    )?;
1487                }
1488            }
1489
1490            self.write_semantic(&member.binding, shader_stage)?;
1491            writeln!(self.out, ";")?;
1492        }
1493
1494        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1495        if members.last().unwrap().binding.is_none() && span > last_offset {
1496            let padding = (span - last_offset) / 4;
1497            for i in 0..padding {
1498                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1499            }
1500        }
1501
1502        writeln!(self.out, "}};")?;
1503        Ok(())
1504    }
1505
1506    /// Helper method used to write global/structs non image/sampler types
1507    ///
1508    /// # Notes
1509    /// Adds no trailing or leading whitespace
1510    pub(super) fn write_global_type(
1511        &mut self,
1512        module: &Module,
1513        ty: Handle<crate::Type>,
1514    ) -> BackendResult {
1515        let matrix_data = get_inner_matrix_data(module, ty);
1516
1517        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1518        // See the module-level block comment in mod.rs for details.
1519        if let Some(MatrixType {
1520            columns,
1521            rows: crate::VectorSize::Bi,
1522            width: 4,
1523        }) = matrix_data
1524        {
1525            write!(self.out, "__mat{}x2", columns as u8)?;
1526        } else {
1527            // Even though Naga IR matrices are column-major, we must describe
1528            // matrices passed from the CPU as being in row-major order.
1529            // See the module-level block comment in mod.rs for details.
1530            if matrix_data.is_some() {
1531                write!(self.out, "row_major ")?;
1532            }
1533
1534            self.write_type(module, ty)?;
1535        }
1536
1537        Ok(())
1538    }
1539
1540    /// Helper method used to write non image/sampler types
1541    ///
1542    /// # Notes
1543    /// Adds no trailing or leading whitespace
1544    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545        let inner = &module.types[ty].inner;
1546        match *inner {
1547            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548            // hlsl array has the size separated from the base type
1549            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550                self.write_type(module, base)?
1551            }
1552            ref other => self.write_value_type(module, other)?,
1553        }
1554
1555        Ok(())
1556    }
1557
1558    /// Helper method used to write value types
1559    ///
1560    /// # Notes
1561    /// Adds no trailing or leading whitespace
1562    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563        match *inner {
1564            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566            }
1567            TypeInner::Vector { size, scalar } => {
1568                write!(
1569                    self.out,
1570                    "{}{}",
1571                    scalar.to_hlsl_str()?,
1572                    common::vector_size_str(size)
1573                )?;
1574            }
1575            TypeInner::Matrix {
1576                columns,
1577                rows,
1578                scalar,
1579            } => {
1580                // The IR supports only float matrix
1581                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1582
1583                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1584                write!(
1585                    self.out,
1586                    "{}{}x{}",
1587                    scalar.to_hlsl_str()?,
1588                    common::vector_size_str(columns),
1589                    common::vector_size_str(rows),
1590                )?;
1591            }
1592            TypeInner::Image {
1593                dim,
1594                arrayed,
1595                class,
1596            } => {
1597                self.write_image_type(dim, arrayed, class)?;
1598            }
1599            TypeInner::Sampler { comparison } => {
1600                let sampler = if comparison {
1601                    "SamplerComparisonState"
1602                } else {
1603                    "SamplerState"
1604                };
1605                write!(self.out, "{sampler}")?;
1606            }
1607            // HLSL arrays are written as `type name[size]`
1608            // Current code is written arrays only as `[size]`
1609            // Base `type` and `name` should be written outside
1610            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611                self.write_array_size(module, base, size)?;
1612            }
1613            TypeInner::AccelerationStructure { .. } => {
1614                write!(self.out, "RaytracingAccelerationStructure")?;
1615            }
1616            TypeInner::RayQuery { .. } => {
1617                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1618                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619            }
1620            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621        }
1622
1623        Ok(())
1624    }
1625
1626    /// Helper method used to write functions
1627    /// # Notes
1628    /// Ends in a newline
1629    fn write_function(
1630        &mut self,
1631        module: &Module,
1632        name: &str,
1633        func: &crate::Function,
1634        func_ctx: &back::FunctionCtx<'_>,
1635        info: &valid::FunctionInfo,
1636        header: String,
1637    ) -> BackendResult {
1638        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1639
1640        self.update_expressions_to_bake(module, func, info);
1641        let ep = match func_ctx.ty {
1642            back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1643            back::FunctionType::Function(_) => None,
1644        };
1645
1646        let nested = matches!(
1647            ep,
1648            Some(crate::EntryPoint {
1649                stage: ShaderStage::Task | ShaderStage::Mesh,
1650                ..
1651            })
1652        );
1653        if !nested {
1654            write!(self.out, "{header}")?;
1655        }
1656
1657        if let Some(ref result) = func.result {
1658            // Write typedef if return type is an array
1659            let array_return_type = match module.types[result.ty].inner {
1660                TypeInner::Array { base, size, .. } => {
1661                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1662                    write!(self.out, "typedef ")?;
1663                    self.write_type(module, result.ty)?;
1664                    write!(self.out, " {array_return_type}")?;
1665                    self.write_array_size(module, base, size)?;
1666                    writeln!(self.out, ";")?;
1667                    Some(array_return_type)
1668                }
1669                _ => None,
1670            };
1671
1672            // Write modifier
1673            if let Some(
1674                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1675            ) = result.binding
1676            {
1677                self.write_modifier(binding)?;
1678            }
1679
1680            // Write return type
1681            match func_ctx.ty {
1682                back::FunctionType::Function(_) => {
1683                    if let Some(array_return_type) = array_return_type {
1684                        write!(self.out, "{array_return_type}")?;
1685                    } else {
1686                        self.write_type(module, result.ty)?;
1687                    }
1688                }
1689                back::FunctionType::EntryPoint(index) => {
1690                    if let Some(ref ep_output) =
1691                        self.entry_point_io.get(&(index as usize)).unwrap().output
1692                    {
1693                        write!(self.out, "{}", ep_output.ty_name)?;
1694                    } else {
1695                        self.write_type(module, result.ty)?;
1696                    }
1697                }
1698            }
1699        } else {
1700            write!(self.out, "void")?;
1701        }
1702
1703        let nested_name = if nested {
1704            self.namer.call(&format!("_{name}"))
1705        } else {
1706            name.to_string()
1707        };
1708
1709        // Write function name
1710        write!(self.out, " {nested_name}(")?;
1711
1712        let need_workgroup_variables_initialization =
1713            self.need_workgroup_variables_initialization(func_ctx, module);
1714
1715        let mut any_args_written = false;
1716        let mut separator = || {
1717            if any_args_written {
1718                ", "
1719            } else {
1720                any_args_written = true;
1721                ""
1722            }
1723        };
1724
1725        let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1726        let mut local_invocation_index_name = None;
1727        // For nested entry points, collect arg names as we write them so that
1728        // write_nested_function_outer can pass the exact same names to the call site.
1729        let mut nested_wgsl_args: Vec<String> = Vec::new();
1730        let mut nested_task_payload_name: Option<String> = None;
1731        // Write function arguments for non entry point functions
1732        match func_ctx.ty {
1733            back::FunctionType::Function(handle) => {
1734                for (index, arg) in func.arguments.iter().enumerate() {
1735                    write!(self.out, "{}", separator())?;
1736                    self.write_function_argument(module, handle, arg, index)?;
1737                }
1738                // If this reads a task payload variable the variable needs to be passed as an `in` argument
1739                for (var_handle, var) in module.global_variables.iter() {
1740                    let uses = info[var_handle];
1741                    if uses.contains(valid::GlobalUse::READ)
1742                        && !uses.contains(valid::GlobalUse::WRITE)
1743                        && var.space == crate::AddressSpace::TaskPayload
1744                    {
1745                        self.function_task_payload_var.insert(handle, var_handle);
1746                        write!(self.out, "{}in ", separator())?;
1747
1748                        self.write_type(module, var.ty)?;
1749                        let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1750                        write!(self.out, " {name}")?;
1751                        break;
1752                    }
1753                }
1754            }
1755            back::FunctionType::EntryPoint(ep_index) => {
1756                let ep = &module.entry_points[ep_index as usize];
1757                if let Some(ref ep_input) =
1758                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1759                {
1760                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1761                    separator();
1762                    nested_wgsl_args.push(ep_input.arg_name.clone());
1763                } else {
1764                    let stage = ep.stage;
1765                    for (index, arg) in func.arguments.iter().enumerate() {
1766                        write!(self.out, "{}", separator())?;
1767                        self.write_type(module, arg.ty)?;
1768
1769                        let argument_name =
1770                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1771
1772                        if arg.binding
1773                            == Some(crate::Binding::BuiltIn(
1774                                crate::BuiltIn::LocalInvocationIndex,
1775                            ))
1776                        {
1777                            local_invocation_index_name = Some(argument_name.clone());
1778                        }
1779
1780                        nested_wgsl_args.push(argument_name.clone());
1781                        write!(self.out, " {argument_name}")?;
1782                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1783                            self.write_array_size(module, base, size)?;
1784                        }
1785
1786                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1787                    }
1788                }
1789                if ep.stage == ShaderStage::Mesh {
1790                    if let Some(var_handle) = ep.task_payload {
1791                        let var = &module.global_variables[var_handle];
1792                        write!(self.out, "{}in ", separator())?;
1793                        self.write_type(module, var.ty)?;
1794                        let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1795                        write!(self.out, " {arg_name}")?;
1796                        nested_task_payload_name = Some(arg_name.clone());
1797                        if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1798                            self.write_array_size(module, base, size)?;
1799                        }
1800                    }
1801                }
1802                if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1803                    let name = self.namer.call("local_invocation_index");
1804                    write!(self.out, "{}uint {name}", separator())?;
1805                    write!(self.out, " : SV_GroupIndex")?;
1806                    local_invocation_index_name = Some(name);
1807                }
1808            }
1809        }
1810        // Ends of arguments
1811        write!(self.out, ")")?;
1812
1813        // Write semantic if it present
1814        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1815            let stage = module.entry_points[index as usize].stage;
1816            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1817                self.write_semantic(binding, Some((stage, Io::Output)))?;
1818            }
1819        }
1820
1821        // Function body start
1822        writeln!(self.out)?;
1823        writeln!(self.out, "{{")?;
1824
1825        if need_workgroup_variables_initialization && !nested {
1826            let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1827                unreachable!();
1828            };
1829            writeln!(
1830                self.out,
1831                "{}if ({} == 0) {{",
1832                back::INDENT,
1833                // need_workgroup_variables_initialization forces this to be written
1834                // if the user doesn't specify it (so this must be Some())
1835                local_invocation_index_name.as_ref().unwrap(),
1836            )?;
1837            self.write_workgroup_variables_initialization(
1838                func_ctx,
1839                module,
1840                module.entry_points[index as usize].stage,
1841            )?;
1842
1843            writeln!(self.out, "{}}}", back::INDENT)?;
1844            self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1845        }
1846
1847        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1848            self.write_ep_arguments_initialization(module, func, index)?;
1849        }
1850
1851        // Write function local variables
1852        for (handle, local) in func.local_variables.iter() {
1853            // Write indentation (only for readability)
1854            write!(self.out, "{}", back::INDENT)?;
1855
1856            // Write the local name
1857            // The leading space is important
1858            self.write_type(module, local.ty)?;
1859            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1860            // Write size for array type
1861            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1862                self.write_array_size(module, base, size)?;
1863            }
1864
1865            let is_ray_query = match module.types[local.ty].inner {
1866                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1867                TypeInner::RayQuery { .. } => true,
1868                _ => {
1869                    write!(self.out, " = ")?;
1870                    // Write the local initializer if needed
1871                    if let Some(init) = local.init {
1872                        self.write_expr(module, init, func_ctx)?;
1873                    } else {
1874                        // Zero initialize local variables
1875                        self.write_default_init(module, local.ty)?;
1876                    }
1877                    false
1878                }
1879            };
1880            // Finish the local with `;` and add a newline (only for readability)
1881            writeln!(self.out, ";")?;
1882            // If it's a ray query, we also want a tracker variable
1883            if is_ray_query {
1884                write!(self.out, "{}", back::INDENT)?;
1885                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1886                writeln!(
1887                    self.out,
1888                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1889                    self.names[&func_ctx.name_key(handle)]
1890                )?;
1891            }
1892        }
1893
1894        if !func.local_variables.is_empty() {
1895            writeln!(self.out)?;
1896        }
1897
1898        // Write the function body (statement list)
1899        for sta in func.body.iter() {
1900            // The indentation should always be 1 when writing the function body
1901            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1902        }
1903
1904        writeln!(self.out, "}}")?;
1905
1906        if nested {
1907            self.write_nested_function_outer(
1908                module,
1909                func_ctx,
1910                &header,
1911                name,
1912                need_workgroup_variables_initialization,
1913                &nested_name,
1914                ep.unwrap(),
1915                NestedEntryPointArgs {
1916                    user_args: nested_wgsl_args,
1917                    task_payload: nested_task_payload_name,
1918                    // guaranteed to be set for nested functions (task/mesh shaders)
1919                    local_invocation_index: local_invocation_index_name.unwrap(),
1920                },
1921            )?;
1922        }
1923
1924        self.named_expressions.clear();
1925
1926        Ok(())
1927    }
1928
1929    fn write_function_argument(
1930        &mut self,
1931        module: &Module,
1932        handle: Handle<crate::Function>,
1933        arg: &crate::FunctionArgument,
1934        index: usize,
1935    ) -> BackendResult {
1936        // External texture arguments must be expanded into separate
1937        // arguments for each plane and the params buffer.
1938        if let TypeInner::Image {
1939            class: crate::ImageClass::External,
1940            ..
1941        } = module.types[arg.ty].inner
1942        {
1943            return self.write_function_external_texture_argument(module, handle, index);
1944        }
1945
1946        // Write argument type
1947        let arg_ty = match module.types[arg.ty].inner {
1948            // pointers in function arguments are expected and resolve to `inout`
1949            TypeInner::Pointer { base, .. } => {
1950                //TODO: can we narrow this down to just `in` when possible?
1951                write!(self.out, "inout ")?;
1952                base
1953            }
1954            _ => arg.ty,
1955        };
1956        self.write_type(module, arg_ty)?;
1957
1958        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960        // Write argument name. Space is important.
1961        write!(self.out, " {argument_name}")?;
1962        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963            self.write_array_size(module, base, size)?;
1964        }
1965
1966        Ok(())
1967    }
1968
1969    fn write_function_external_texture_argument(
1970        &mut self,
1971        module: &Module,
1972        handle: Handle<crate::Function>,
1973        index: usize,
1974    ) -> BackendResult {
1975        let plane_names = [0, 1, 2].map(|i| {
1976            &self.names[&NameKey::ExternalTextureFunctionArgument(
1977                handle,
1978                index as u32,
1979                ExternalTextureNameKey::Plane(i),
1980            )]
1981        });
1982        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983            handle,
1984            index as u32,
1985            ExternalTextureNameKey::Params,
1986        )];
1987        let params_ty_name =
1988            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989        write!(
1990            self.out,
1991            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992            plane_names[0], plane_names[1], plane_names[2],
1993        )?;
1994        Ok(())
1995    }
1996
1997    fn need_workgroup_variables_initialization(
1998        &mut self,
1999        func_ctx: &back::FunctionCtx,
2000        module: &Module,
2001    ) -> bool {
2002        self.options.zero_initialize_workgroup_memory
2003            && func_ctx.ty.is_compute_like_entry_point(module)
2004            && module.global_variables.iter().any(|(handle, var)| {
2005                !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006            })
2007    }
2008
2009    pub(super) fn write_workgroup_variables_initialization(
2010        &mut self,
2011        func_ctx: &back::FunctionCtx,
2012        module: &Module,
2013        stage: ShaderStage,
2014    ) -> BackendResult {
2015        let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016            // Read-only in mesh shaders
2017            let task_needs_zero =
2018                (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019            !func_ctx.info[handle].is_empty()
2020                && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021        });
2022
2023        for (handle, var) in vars {
2024            let name = &self.names[&NameKey::GlobalVariable(handle)];
2025            write!(self.out, "{}{} = ", back::Level(2), name)?;
2026            self.write_default_init(module, var.ty)?;
2027            writeln!(self.out, ";")?;
2028        }
2029        Ok(())
2030    }
2031
    /// Helper method used to write switches
    ///
    /// Writes a switch over `selector` with the given `cases` at indentation
    /// `level`. Two output shapes are possible:
    ///
    /// - If every case except the last is an empty fall-through case, only the
    ///   last case's body is written, wrapped in `do { ... } while(false);`,
    ///   because FXC does not handle such switches correctly.
    /// - Otherwise a real HLSL `switch` is written. FXC does not support falling
    ///   through out of a non-empty case, so the bodies of the cases it would
    ///   fall into are duplicated into that case (potentially quadratic output).
    ///
    /// Around either shape, `self.continue_ctx` may cause a `bool` flag to be
    /// declared before the switch and tested afterwards, re-issuing `continue`
    /// or `break` (see `back::continue_forward`).
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        // Write all cases
        // Indentation for `case` labels (level 1) and for case bodies (level 2).
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // See docs of `back::continue_forward` module.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // Check if there is only one body, by seeing if all except the last case are fall through
        // with empty bodies. FXC doesn't handle these switches correctly, so
        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
        // is no need to check if one of the cases would have matched.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Start the do-while
            writeln!(self.out, "{level}do {{")?;
            // Note: Expressions have no side-effects so we don't need to emit selector expression.

            // Body
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            // End do-while
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            // Start the switch
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                // Case label; `u32` selector values get a `u` suffix.
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // The new block is not only stylistic, it plays a role here:
                // We might end up having to write the same case body
                // multiple times due to FXC not supporting fallthrough.
                // Therefore, some `Expression`s written by `Statement::Emit`
                // will end up having the same name (`_expr<handle_index>`).
                // So we need to put each case in its own scope.
                // An empty fall-through case is just a stacked label and gets no braces.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                // Although FXC does support a series of case clauses before
                // a block[^yes], it does not support fallthrough from a
                // non-empty case block to the next[^no]. If this case has a
                // non-empty body with a fallthrough, emulate that by
                // duplicating the bodies of all the cases it would fall
                // into as extensions of this case's own body. This makes
                // the HLSL output potentially quadratic in the size of the
                // Naga IR.
                //
                // [^yes]: ```hlsl
                // case 1:
                // case 2: do_stuff()
                // ```
                // [^no]: ```hlsl
                // case 1: do_this();
                // case 2: do_that();
                // ```
                if case.fall_through && !case.body.is_empty() {
                    // `end_case_idx` is the first case after `i` that does not
                    // fall through. The `unwrap` assumes such a case exists
                    // (presumably guaranteed by validation; confirm).
                    let curr_len = i + 1;
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Clear all named expressions that were previously inserted by the statements in the block
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    // Only add a `break` if the final duplicated body doesn't
                    // already end in a terminator statement.
                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    // Non-fall-through cases need an explicit `break` unless
                    // their body already ends in a terminator.
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Handle any forwarded continue statements.
        // If a flag was declared on entry, test it here and re-issue the
        // control flow statement chosen by `continue_ctx`.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
2176
2177    fn write_index(
2178        &mut self,
2179        module: &Module,
2180        index: Index,
2181        func_ctx: &back::FunctionCtx<'_>,
2182    ) -> BackendResult {
2183        match index {
2184            Index::Static(index) => {
2185                write!(self.out, "{index}")?;
2186            }
2187            Index::Expression(index) => {
2188                self.write_expr(module, index, func_ctx)?;
2189            }
2190        }
2191        Ok(())
2192    }
2193
2194    /// Helper method used to write statements
2195    ///
2196    /// # Notes
2197    /// Always adds a newline
2198    fn write_stmt(
2199        &mut self,
2200        module: &Module,
2201        stmt: &crate::Statement,
2202        func_ctx: &back::FunctionCtx<'_>,
2203        level: back::Level,
2204    ) -> BackendResult {
2205        use crate::Statement;
2206
2207        match *stmt {
2208            Statement::Emit(ref range) => {
2209                for handle in range.clone() {
2210                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211                    let expr_name = if ptr_class.is_some() {
2212                        // HLSL can't save a pointer-valued expression in a variable,
2213                        // but we shouldn't ever need to: they should never be named expressions,
2214                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
2215                        None
2216                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217                        // Front end provides names for all variables at the start of writing.
2218                        // But we write them to step by step. We need to recache them
2219                        // Otherwise, we could accidentally write variable name instead of full expression.
2220                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
2221                        Some(self.namer.call(name))
2222                    } else if self.need_bake_expressions.contains(&handle) {
2223                        Some(Baked(handle).to_string())
2224                    } else {
2225                        None
2226                    };
2227
2228                    if let Some(name) = expr_name {
2229                        write!(self.out, "{level}")?;
2230                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231                    }
2232                }
2233            }
2234            // TODO: copy-paste from glsl-out
2235            Statement::Block(ref block) => {
2236                write!(self.out, "{level}")?;
2237                writeln!(self.out, "{{")?;
2238                for sta in block.iter() {
2239                    // Increase the indentation to help with readability
2240                    self.write_stmt(module, sta, func_ctx, level.next())?
2241                }
2242                writeln!(self.out, "{level}}}")?
2243            }
2244            // TODO: copy-paste from glsl-out
2245            Statement::If {
2246                condition,
2247                ref accept,
2248                ref reject,
2249            } => {
2250                write!(self.out, "{level}")?;
2251                write!(self.out, "if (")?;
2252                self.write_expr(module, condition, func_ctx)?;
2253                writeln!(self.out, ") {{")?;
2254
2255                let l2 = level.next();
2256                for sta in accept {
2257                    // Increase indentation to help with readability
2258                    self.write_stmt(module, sta, func_ctx, l2)?;
2259                }
2260
2261                // If there are no statements in the reject block we skip writing it
2262                // This is only for readability
2263                if !reject.is_empty() {
2264                    writeln!(self.out, "{level}}} else {{")?;
2265
2266                    for sta in reject {
2267                        // Increase indentation to help with readability
2268                        self.write_stmt(module, sta, func_ctx, l2)?;
2269                    }
2270                }
2271
2272                writeln!(self.out, "{level}}}")?
2273            }
2274            // TODO: copy-paste from glsl-out
2275            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276            Statement::Return { value: None } => {
2277                writeln!(self.out, "{level}return;")?;
2278            }
2279            Statement::Return { value: Some(expr) } => {
2280                let base_ty_res = &func_ctx.info[expr].ty;
2281                let mut resolved = base_ty_res.inner_with(&module.types);
2282                if let TypeInner::Pointer { base, space: _ } = *resolved {
2283                    resolved = &module.types[base].inner;
2284                }
2285
2286                if let TypeInner::Struct { .. } = *resolved {
2287                    // We can safely unwrap here, since we now we working with struct
2288                    let ty = base_ty_res.handle().unwrap();
2289                    let struct_name = &self.names[&NameKey::Type(ty)];
2290                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2291                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292                    self.write_expr(module, expr, func_ctx)?;
2293                    writeln!(self.out, ";")?;
2294
2295                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2296                    let ep_output = match func_ctx.ty {
2297                        back::FunctionType::Function(_) => None,
2298                        back::FunctionType::EntryPoint(index) => self
2299                            .entry_point_io
2300                            .get(&(index as usize))
2301                            .unwrap()
2302                            .output
2303                            .as_ref(),
2304                    };
2305                    let final_name = match ep_output {
2306                        Some(ep_output) => {
2307                            let final_name = self.namer.call(&variable_name);
2308                            write!(
2309                                self.out,
2310                                "{}const {} {} = {{ ",
2311                                level, ep_output.ty_name, final_name,
2312                            )?;
2313                            for (index, m) in ep_output.members.iter().enumerate() {
2314                                if index != 0 {
2315                                    write!(self.out, ", ")?;
2316                                }
2317                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318                                write!(self.out, "{variable_name}.{member_name}")?;
2319                            }
2320                            writeln!(self.out, " }};")?;
2321                            final_name
2322                        }
2323                        None => variable_name,
2324                    };
2325                    writeln!(self.out, "{level}return {final_name};")?;
2326                } else {
2327                    write!(self.out, "{level}return ")?;
2328                    self.write_expr(module, expr, func_ctx)?;
2329                    writeln!(self.out, ";")?
2330                }
2331            }
2332            Statement::Store { pointer, value } => {
2333                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334                if ty_inner.is_atomic_pointer(&module.types) {
2335                    let pointer_space = ty_inner.pointer_space().unwrap();
2336                    let dummy = self.namer.call("dummy");
2337                    write!(self.out, "{level}{{ ")?;
2338                    if let TypeInner::Pointer { base, .. } = *ty_inner {
2339                        self.write_value_type(module, &module.types[base].inner)?;
2340                    }
2341                    write!(self.out, " {dummy} = 0; ")?;
2342                    match pointer_space {
2343                        crate::AddressSpace::WorkGroup => {
2344                            write!(self.out, "InterlockedExchange(")?;
2345                            self.write_expr(module, pointer, func_ctx)?;
2346                        }
2347                        crate::AddressSpace::Storage { .. } => {
2348                            let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349                            let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350                            write!(self.out, "{var_name}.InterlockedExchange(")?;
2351                            let chain = mem::take(&mut self.temp_access_chain);
2352                            self.write_storage_address(module, &chain, func_ctx)?;
2353                            self.temp_access_chain = chain;
2354                        }
2355                        _ => unreachable!(),
2356                    }
2357                    write!(self.out, ", ")?;
2358                    self.write_expr(module, value, func_ctx)?;
2359                    writeln!(self.out, ", {dummy}); }}")?;
2360                } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362                    self.write_storage_store(
2363                        module,
2364                        var_handle,
2365                        StoreValue::Expression(value),
2366                        func_ctx,
2367                        level,
2368                        None,
2369                    )?;
2370                } else {
2371                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2372                    // See the module-level block comment in mod.rs for details.
2373                    //
2374                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2375                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2376                    enum MatrixAccess {
2377                        Direct {
2378                            base: Handle<crate::Expression>,
2379                            index: u32,
2380                        },
2381                        Struct {
2382                            columns: crate::VectorSize,
2383                            base: Handle<crate::Expression>,
2384                        },
2385                    }
2386
2387                    let get_members = |expr: Handle<crate::Expression>| {
2388                        let resolved = func_ctx.resolve_type(expr, &module.types);
2389                        match *resolved {
2390                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2391                                TypeInner::Struct { ref members, .. } => Some(members),
2392                                _ => None,
2393                            },
2394                            _ => None,
2395                        }
2396                    };
2397
2398                    write!(self.out, "{level}")?;
2399
2400                    let matrix_access_on_lhs =
2401                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2402                            |(matrix_expr, vector, scalar)| match (
2403                                func_ctx.resolve_type(matrix_expr, &module.types),
2404                                &func_ctx.expressions[matrix_expr],
2405                            ) {
2406                                (
2407                                    &TypeInner::Pointer { base: ty, .. },
2408                                    &crate::Expression::AccessIndex { base, index },
2409                                ) if matches!(
2410                                    module.types[ty].inner,
2411                                    TypeInner::Matrix {
2412                                        rows: crate::VectorSize::Bi,
2413                                        ..
2414                                    }
2415                                ) && get_members(base)
2416                                    .map(|members| members[index as usize].binding.is_none())
2417                                    == Some(true) =>
2418                                {
2419                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2420                                }
2421                                _ => {
2422                                    if let Some(MatrixType {
2423                                        columns,
2424                                        rows: crate::VectorSize::Bi,
2425                                        width: 4,
2426                                    }) = get_inner_matrix_of_struct_array_member(
2427                                        module,
2428                                        matrix_expr,
2429                                        func_ctx,
2430                                        true,
2431                                    ) {
2432                                        Some((
2433                                            MatrixAccess::Struct {
2434                                                columns,
2435                                                base: matrix_expr,
2436                                            },
2437                                            vector,
2438                                            scalar,
2439                                        ))
2440                                    } else {
2441                                        None
2442                                    }
2443                                }
2444                            },
2445                        );
2446
2447                    match matrix_access_on_lhs {
2448                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2449                            let base_ty_res = &func_ctx.info[base].ty;
2450                            let resolved = base_ty_res.inner_with(&module.types);
2451                            let ty = match *resolved {
2452                                TypeInner::Pointer { base, .. } => base,
2453                                _ => base_ty_res.handle().unwrap(),
2454                            };
2455
2456                            if let Some(Index::Static(vec_index)) = vector {
2457                                self.write_expr(module, base, func_ctx)?;
2458                                write!(
2459                                    self.out,
2460                                    ".{}_{}",
2461                                    &self.names[&NameKey::StructMember(ty, index)],
2462                                    vec_index
2463                                )?;
2464
2465                                if let Some(scalar_index) = scalar {
2466                                    write!(self.out, "[")?;
2467                                    self.write_index(module, scalar_index, func_ctx)?;
2468                                    write!(self.out, "]")?;
2469                                }
2470
2471                                write!(self.out, " = ")?;
2472                                self.write_expr(module, value, func_ctx)?;
2473                                writeln!(self.out, ";")?;
2474                            } else {
2475                                let access = WrappedStructMatrixAccess { ty, index };
2476                                match (&vector, &scalar) {
2477                                    (&Some(_), &Some(_)) => {
2478                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2479                                            access,
2480                                        )?;
2481                                    }
2482                                    (&Some(_), &None) => {
2483                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2484                                            access,
2485                                        )?;
2486                                    }
2487                                    (&None, _) => {
2488                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2489                                    }
2490                                }
2491
2492                                write!(self.out, "(")?;
2493                                self.write_expr(module, base, func_ctx)?;
2494                                write!(self.out, ", ")?;
2495                                self.write_expr(module, value, func_ctx)?;
2496
2497                                if let Some(Index::Expression(vec_index)) = vector {
2498                                    write!(self.out, ", ")?;
2499                                    self.write_expr(module, vec_index, func_ctx)?;
2500
2501                                    if let Some(scalar_index) = scalar {
2502                                        write!(self.out, ", ")?;
2503                                        self.write_index(module, scalar_index, func_ctx)?;
2504                                    }
2505                                }
2506                                writeln!(self.out, ");")?;
2507                            }
2508                        }
2509                        Some((
2510                            MatrixAccess::Struct { columns, base },
2511                            Some(Index::Expression(vec_index)),
2512                            scalar,
2513                        )) => {
2514                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2515                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2516
2517                            if scalar.is_some() {
2518                                write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
2519                            } else {
2520                                write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
2521                            }
2522                            write!(self.out, "(")?;
2523                            self.write_expr(module, base, func_ctx)?;
2524                            write!(self.out, ", ")?;
2525                            self.write_expr(module, vec_index, func_ctx)?;
2526
2527                            if let Some(scalar_index) = scalar {
2528                                write!(self.out, ", ")?;
2529                                self.write_index(module, scalar_index, func_ctx)?;
2530                            }
2531
2532                            write!(self.out, ", ")?;
2533                            self.write_expr(module, value, func_ctx)?;
2534
2535                            writeln!(self.out, ");")?;
2536                        }
2537                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2538                        | Some((MatrixAccess::Struct { .. }, None, _))
2539                        | None => {
2540                            self.write_expr(module, pointer, func_ctx)?;
2541                            write!(self.out, " = ")?;
2542
2543                            // We cast the RHS of this store in cases where the LHS
2544                            // is a struct member with type:
2545                            //  - matCx2 or
2546                            //  - a (possibly nested) array of matCx2's
2547                            if let Some(MatrixType {
2548                                columns,
2549                                rows: crate::VectorSize::Bi,
2550                                width: 4,
2551                            }) = get_inner_matrix_of_struct_array_member(
2552                                module, pointer, func_ctx, false,
2553                            ) {
2554                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2555                                if let TypeInner::Pointer { base, .. } = *resolved {
2556                                    resolved = &module.types[base].inner;
2557                                }
2558
2559                                write!(self.out, "(__mat{}x2", columns as u8)?;
2560                                if let TypeInner::Array { base, size, .. } = *resolved {
2561                                    self.write_array_size(module, base, size)?;
2562                                }
2563                                write!(self.out, ")")?;
2564                            }
2565
2566                            self.write_expr(module, value, func_ctx)?;
2567                            writeln!(self.out, ";")?
2568                        }
2569                    }
2570                }
2571            }
2572            Statement::Loop {
2573                ref body,
2574                ref continuing,
2575                break_if,
2576            } => {
2577                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2578                let gate_name = (!continuing.is_empty() || break_if.is_some())
2579                    .then(|| self.namer.call("loop_init"));
2580
2581                if let Some((ref decl, _)) = force_loop_bound_statements {
2582                    writeln!(self.out, "{decl}")?;
2583                }
2584                if let Some(ref gate_name) = gate_name {
2585                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2586                }
2587
2588                self.continue_ctx.enter_loop();
2589                writeln!(self.out, "{level}while(true) {{")?;
2590                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2591                    writeln!(self.out, "{break_and_inc}")?;
2592                }
2593                let l2 = level.next();
2594                if let Some(gate_name) = gate_name {
2595                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2596                    let l3 = l2.next();
2597                    for sta in continuing.iter() {
2598                        self.write_stmt(module, sta, func_ctx, l3)?;
2599                    }
2600                    if let Some(condition) = break_if {
2601                        write!(self.out, "{l3}if (")?;
2602                        self.write_expr(module, condition, func_ctx)?;
2603                        writeln!(self.out, ") {{")?;
2604                        writeln!(self.out, "{}break;", l3.next())?;
2605                        writeln!(self.out, "{l3}}}")?;
2606                    }
2607                    writeln!(self.out, "{l2}}}")?;
2608                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2609                }
2610
2611                for sta in body.iter() {
2612                    self.write_stmt(module, sta, func_ctx, l2)?;
2613                }
2614
2615                writeln!(self.out, "{level}}}")?;
2616                self.continue_ctx.exit_loop();
2617            }
2618            Statement::Break => writeln!(self.out, "{level}break;")?,
2619            Statement::Continue => {
2620                if let Some(variable) = self.continue_ctx.continue_encountered() {
2621                    writeln!(self.out, "{level}{variable} = true;")?;
2622                    writeln!(self.out, "{level}break;")?
2623                } else {
2624                    writeln!(self.out, "{level}continue;")?
2625                }
2626            }
2627            Statement::ControlBarrier(barrier) => {
2628                self.write_control_barrier(barrier, level)?;
2629            }
2630            Statement::MemoryBarrier(barrier) => {
2631                self.write_memory_barrier(barrier, level)?;
2632            }
2633            Statement::ImageStore {
2634                image,
2635                coordinate,
2636                array_index,
2637                value,
2638            } => {
2639                write!(self.out, "{level}")?;
2640                self.write_expr(module, image, func_ctx)?;
2641
2642                write!(self.out, "[")?;
2643                if let Some(index) = array_index {
2644                    // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
2645                    write!(self.out, "int3(")?;
2646                    self.write_expr(module, coordinate, func_ctx)?;
2647                    write!(self.out, ", ")?;
2648                    self.write_expr(module, index, func_ctx)?;
2649                    write!(self.out, ")")?;
2650                } else {
2651                    self.write_expr(module, coordinate, func_ctx)?;
2652                }
2653                write!(self.out, "]")?;
2654
2655                write!(self.out, " = ")?;
2656                self.write_expr(module, value, func_ctx)?;
2657                writeln!(self.out, ";")?;
2658            }
2659            Statement::Call {
2660                function,
2661                ref arguments,
2662                result,
2663            } => {
2664                write!(self.out, "{level}")?;
2665
2666                if let Some(expr) = result {
2667                    write!(self.out, "const ")?;
2668                    let name = Baked(expr).to_string();
2669                    let expr_ty = &func_ctx.info[expr].ty;
2670                    let ty_inner = match *expr_ty {
2671                        proc::TypeResolution::Handle(handle) => {
2672                            self.write_type(module, handle)?;
2673                            &module.types[handle].inner
2674                        }
2675                        proc::TypeResolution::Value(ref value) => {
2676                            self.write_value_type(module, value)?;
2677                            value
2678                        }
2679                    };
2680                    write!(self.out, " {name}")?;
2681                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2682                        self.write_array_size(module, base, size)?;
2683                    }
2684                    write!(self.out, " = ")?;
2685                    self.named_expressions.insert(expr, name);
2686                }
2687                let func_name = &self.names[&NameKey::Function(function)];
2688                write!(self.out, "{func_name}(")?;
2689                let mut any_args_written = false;
2690                let mut separator = || {
2691                    if any_args_written {
2692                        ", "
2693                    } else {
2694                        any_args_written = true;
2695                        ""
2696                    }
2697                };
2698                for argument in arguments {
2699                    write!(self.out, "{}", separator())?;
2700                    self.write_expr(module, *argument, func_ctx)?;
2701                }
2702                if let Some(&var) = self.function_task_payload_var.get(&function) {
2703                    let name = &self.names[&NameKey::GlobalVariable(var)];
2704                    // Pass it through directly, whether its an in variable to this function or the global variable
2705                    write!(self.out, "{}{name}", separator())?;
2706                }
2707                writeln!(self.out, ");")?;
2708            }
2709            Statement::Atomic {
2710                pointer,
2711                ref fun,
2712                value,
2713                result,
2714            } => {
2715                write!(self.out, "{level}")?;
2716                let res_var_info = if let Some(res_handle) = result {
2717                    let name = Baked(res_handle).to_string();
2718                    match func_ctx.info[res_handle].ty {
2719                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2720                        proc::TypeResolution::Value(ref value) => {
2721                            self.write_value_type(module, value)?
2722                        }
2723                    };
2724                    write!(self.out, " {name}; ")?;
2725                    self.named_expressions.insert(res_handle, name.clone());
2726                    Some((res_handle, name))
2727                } else {
2728                    None
2729                };
2730                let pointer_space = func_ctx
2731                    .resolve_type(pointer, &module.types)
2732                    .pointer_space()
2733                    .unwrap();
2734                let fun_str = fun.to_hlsl_suffix();
2735                let compare_expr = match *fun {
2736                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2737                    _ => None,
2738                };
2739                match pointer_space {
2740                    crate::AddressSpace::WorkGroup => {
2741                        write!(self.out, "Interlocked{fun_str}(")?;
2742                        self.write_expr(module, pointer, func_ctx)?;
2743                        self.emit_hlsl_atomic_tail(
2744                            module,
2745                            func_ctx,
2746                            fun,
2747                            compare_expr,
2748                            value,
2749                            &res_var_info,
2750                        )?;
2751                    }
2752                    crate::AddressSpace::Storage { .. } => {
2753                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2754                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2755                        let width = match func_ctx.resolve_type(value, &module.types) {
2756                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2757                            _ => "",
2758                        };
2759                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2760                        let chain = mem::take(&mut self.temp_access_chain);
2761                        self.write_storage_address(module, &chain, func_ctx)?;
2762                        self.temp_access_chain = chain;
2763                        self.emit_hlsl_atomic_tail(
2764                            module,
2765                            func_ctx,
2766                            fun,
2767                            compare_expr,
2768                            value,
2769                            &res_var_info,
2770                        )?;
2771                    }
2772                    ref other => {
2773                        return Err(Error::Custom(format!(
2774                            "invalid address space {other:?} for atomic statement"
2775                        )))
2776                    }
2777                }
2778                if let Some(cmp) = compare_expr {
2779                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2780                        write!(
2781                            self.out,
2782                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2783                        )?;
2784                        self.write_expr(module, cmp, func_ctx)?;
2785                        writeln!(self.out, ");")?;
2786                    }
2787                }
2788            }
2789            Statement::ImageAtomic {
2790                image,
2791                coordinate,
2792                array_index,
2793                fun,
2794                value,
2795            } => {
2796                write!(self.out, "{level}")?;
2797
2798                let fun_str = fun.to_hlsl_suffix();
2799                write!(self.out, "Interlocked{fun_str}(")?;
2800                self.write_expr(module, image, func_ctx)?;
2801                write!(self.out, "[")?;
2802                self.write_texture_coordinates(
2803                    "int",
2804                    coordinate,
2805                    array_index,
2806                    None,
2807                    module,
2808                    func_ctx,
2809                )?;
2810                write!(self.out, "],")?;
2811
2812                self.write_expr(module, value, func_ctx)?;
2813                writeln!(self.out, ");")?;
2814            }
2815            Statement::WorkGroupUniformLoad { pointer, result } => {
2816                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2817                write!(self.out, "{level}")?;
2818                let name = Baked(result).to_string();
2819                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2820
2821                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2822            }
2823            Statement::Switch {
2824                selector,
2825                ref cases,
2826            } => {
2827                self.write_switch(module, func_ctx, level, selector, cases)?;
2828            }
2829            Statement::RayQuery { query, ref fun } => {
2830                // There are three possibilities for a ptr to be:
2831                // 1. A variable
2832                // 2. A function argument
2833                // 3. part of a struct
2834                //
2835                // 2 and 3 are not possible, a ray query (in naga IR)
2836                // is not allowed to be passed into a function, and
2837                // all languages disallow it in a struct (you get fun results if
2838                // you try it :) ).
2839                //
2840                // Therefore, the ray query expression must be a variable.
2841                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2842                else {
2843                    unreachable!()
2844                };
2845
2846                let tracker_expr_name = format!(
2847                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2848                    self.names[&func_ctx.name_key(query_var)]
2849                );
2850
2851                match *fun {
2852                    RayQueryFunction::Initialize {
2853                        acceleration_structure,
2854                        descriptor,
2855                    } => {
2856                        self.write_initialize_function(
2857                            module,
2858                            level,
2859                            query,
2860                            acceleration_structure,
2861                            descriptor,
2862                            &tracker_expr_name,
2863                            func_ctx,
2864                        )?;
2865                    }
2866                    RayQueryFunction::Proceed { result } => {
2867                        self.write_proceed(
2868                            module,
2869                            level,
2870                            query,
2871                            result,
2872                            &tracker_expr_name,
2873                            func_ctx,
2874                        )?;
2875                    }
2876                    RayQueryFunction::GenerateIntersection { hit_t } => {
2877                        self.write_generate_intersection(
2878                            module,
2879                            level,
2880                            query,
2881                            hit_t,
2882                            &tracker_expr_name,
2883                            func_ctx,
2884                        )?;
2885                    }
2886                    RayQueryFunction::ConfirmIntersection => {
2887                        self.write_confirm_intersection(
2888                            module,
2889                            level,
2890                            query,
2891                            &tracker_expr_name,
2892                            func_ctx,
2893                        )?;
2894                    }
2895                    RayQueryFunction::Terminate => {
2896                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2897                    }
2898                }
2899            }
2900            Statement::SubgroupBallot { result, predicate } => {
2901                write!(self.out, "{level}")?;
2902                let name = Baked(result).to_string();
2903                write!(self.out, "const uint4 {name} = ")?;
2904                self.named_expressions.insert(result, name);
2905
2906                write!(self.out, "WaveActiveBallot(")?;
2907                match predicate {
2908                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2909                    None => write!(self.out, "true")?,
2910                }
2911                writeln!(self.out, ");")?;
2912            }
2913            Statement::SubgroupCollectiveOperation {
2914                op,
2915                collective_op,
2916                argument,
2917                result,
2918            } => {
2919                write!(self.out, "{level}")?;
2920                write!(self.out, "const ")?;
2921                let name = Baked(result).to_string();
2922                match func_ctx.info[result].ty {
2923                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2924                    proc::TypeResolution::Value(ref value) => {
2925                        self.write_value_type(module, value)?
2926                    }
2927                };
2928                write!(self.out, " {name} = ")?;
2929                self.named_expressions.insert(result, name);
2930
2931                match (collective_op, op) {
2932                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2933                        write!(self.out, "WaveActiveAllTrue(")?
2934                    }
2935                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2936                        write!(self.out, "WaveActiveAnyTrue(")?
2937                    }
2938                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2939                        write!(self.out, "WaveActiveSum(")?
2940                    }
2941                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2942                        write!(self.out, "WaveActiveProduct(")?
2943                    }
2944                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2945                        write!(self.out, "WaveActiveMax(")?
2946                    }
2947                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2948                        write!(self.out, "WaveActiveMin(")?
2949                    }
2950                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2951                        write!(self.out, "WaveActiveBitAnd(")?
2952                    }
2953                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2954                        write!(self.out, "WaveActiveBitOr(")?
2955                    }
2956                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2957                        write!(self.out, "WaveActiveBitXor(")?
2958                    }
2959                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2960                        write!(self.out, "WavePrefixSum(")?
2961                    }
2962                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2963                        write!(self.out, "WavePrefixProduct(")?
2964                    }
2965                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2966                        self.write_expr(module, argument, func_ctx)?;
2967                        write!(self.out, " + WavePrefixSum(")?;
2968                    }
2969                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2970                        self.write_expr(module, argument, func_ctx)?;
2971                        write!(self.out, " * WavePrefixProduct(")?;
2972                    }
2973                    _ => unimplemented!(),
2974                }
2975                self.write_expr(module, argument, func_ctx)?;
2976                writeln!(self.out, ");")?;
2977            }
2978            Statement::SubgroupGather {
2979                mode,
2980                argument,
2981                result,
2982            } => {
2983                write!(self.out, "{level}")?;
2984                write!(self.out, "const ")?;
2985                let name = Baked(result).to_string();
2986                match func_ctx.info[result].ty {
2987                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2988                    proc::TypeResolution::Value(ref value) => {
2989                        self.write_value_type(module, value)?
2990                    }
2991                };
2992                write!(self.out, " {name} = ")?;
2993                self.named_expressions.insert(result, name);
2994                match mode {
2995                    crate::GatherMode::BroadcastFirst => {
2996                        write!(self.out, "WaveReadLaneFirst(")?;
2997                        self.write_expr(module, argument, func_ctx)?;
2998                    }
2999                    crate::GatherMode::QuadBroadcast(index) => {
3000                        write!(self.out, "QuadReadLaneAt(")?;
3001                        self.write_expr(module, argument, func_ctx)?;
3002                        write!(self.out, ", ")?;
3003                        self.write_expr(module, index, func_ctx)?;
3004                    }
3005                    crate::GatherMode::QuadSwap(direction) => {
3006                        match direction {
3007                            crate::Direction::X => {
3008                                write!(self.out, "QuadReadAcrossX(")?;
3009                            }
3010                            crate::Direction::Y => {
3011                                write!(self.out, "QuadReadAcrossY(")?;
3012                            }
3013                            crate::Direction::Diagonal => {
3014                                write!(self.out, "QuadReadAcrossDiagonal(")?;
3015                            }
3016                        }
3017                        self.write_expr(module, argument, func_ctx)?;
3018                    }
3019                    _ => {
3020                        write!(self.out, "WaveReadLaneAt(")?;
3021                        self.write_expr(module, argument, func_ctx)?;
3022                        write!(self.out, ", ")?;
3023                        match mode {
3024                            crate::GatherMode::BroadcastFirst => unreachable!(),
3025                            crate::GatherMode::Broadcast(index)
3026                            | crate::GatherMode::Shuffle(index) => {
3027                                self.write_expr(module, index, func_ctx)?;
3028                            }
3029                            crate::GatherMode::ShuffleDown(index) => {
3030                                write!(self.out, "WaveGetLaneIndex() + ")?;
3031                                self.write_expr(module, index, func_ctx)?;
3032                            }
3033                            crate::GatherMode::ShuffleUp(index) => {
3034                                write!(self.out, "WaveGetLaneIndex() - ")?;
3035                                self.write_expr(module, index, func_ctx)?;
3036                            }
3037                            crate::GatherMode::ShuffleXor(index) => {
3038                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
3039                                self.write_expr(module, index, func_ctx)?;
3040                            }
3041                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3042                            crate::GatherMode::QuadSwap(_) => unreachable!(),
3043                        }
3044                    }
3045                }
3046                writeln!(self.out, ");")?;
3047            }
3048            Statement::CooperativeStore { .. } => unimplemented!(),
3049            Statement::RayPipelineFunction(_) => unreachable!(),
3050        }
3051
3052        Ok(())
3053    }
3054
3055    fn write_const_expression(
3056        &mut self,
3057        module: &Module,
3058        expr: Handle<crate::Expression>,
3059        arena: &crate::Arena<crate::Expression>,
3060    ) -> BackendResult {
3061        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3062            writer.write_const_expression(module, expr, arena)
3063        })
3064    }
3065
3066    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3067        match literal {
3068            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3069            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3070            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3071            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3072            // `-2147483648` is parsed by some compilers as unary negation of
3073            // positive 2147483648, which is too large for an int, causing
3074            // issues for some compilers. Neither DXC nor FXC appear to have
3075            // this problem, but this is not specified and could change. We
3076            // therefore use `-2147483647 - 1` as a precaution.
3077            crate::Literal::I32(value) if value == i32::MIN => {
3078                write!(self.out, "int({} - 1)", value + 1)?
3079            }
3080            // HLSL has no suffix for explicit i32 literals, but not using any suffix
3081            // makes the type ambiguous which prevents overload resolution from
3082            // working. So we explicitly use the int() constructor syntax.
3083            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3084            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3085            // I64 version of the minimum I32 value issue described above.
3086            crate::Literal::I64(value) if value == i64::MIN => {
3087                write!(self.out, "({}L - 1L)", value + 1)?;
3088            }
3089            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3090            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3091            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3092                return Err(Error::Custom(
3093                    "Abstract types should not appear in IR presented to backends".into(),
3094                ));
3095            }
3096        }
3097        Ok(())
3098    }
3099
3100    fn write_possibly_const_expression<E>(
3101        &mut self,
3102        module: &Module,
3103        expr: Handle<crate::Expression>,
3104        expressions: &crate::Arena<crate::Expression>,
3105        write_expression: E,
3106    ) -> BackendResult
3107    where
3108        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3109    {
3110        use crate::Expression;
3111
3112        match expressions[expr] {
3113            Expression::Literal(literal) => {
3114                self.write_literal(literal)?;
3115            }
3116            Expression::Constant(handle) => {
3117                let constant = &module.constants[handle];
3118                if constant.name.is_some() {
3119                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3120                } else {
3121                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
3122                }
3123            }
3124            Expression::ZeroValue(ty) => {
3125                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3126                write!(self.out, "()")?;
3127            }
3128            Expression::Compose { ty, ref components } => {
3129                match module.types[ty].inner {
3130                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3131                        self.write_wrapped_constructor_function_name(
3132                            module,
3133                            WrappedConstructor { ty },
3134                        )?;
3135                    }
3136                    _ => {
3137                        self.write_type(module, ty)?;
3138                    }
3139                };
3140                write!(self.out, "(")?;
3141                for (index, component) in components.iter().enumerate() {
3142                    if index != 0 {
3143                        write!(self.out, ", ")?;
3144                    }
3145                    write_expression(self, *component)?;
3146                }
3147                write!(self.out, ")")?;
3148            }
3149            Expression::Splat { size, value } => {
3150                // hlsl is not supported one value constructor
3151                // if we write, for example, int4(0), dxc returns error:
3152                // error: too few elements in vector initialization (expected 4 elements, have 1)
3153                let number_of_components = match size {
3154                    crate::VectorSize::Bi => "xx",
3155                    crate::VectorSize::Tri => "xxx",
3156                    crate::VectorSize::Quad => "xxxx",
3157                };
3158                write!(self.out, "(")?;
3159                write_expression(self, value)?;
3160                write!(self.out, ").{number_of_components}")?
3161            }
3162            _ => {
3163                return Err(Error::Override);
3164            }
3165        }
3166
3167        Ok(())
3168    }
3169
3170    /// Helper method to write expressions
3171    ///
3172    /// # Notes
3173    /// Doesn't add any newlines or leading/trailing spaces
3174    pub(super) fn write_expr(
3175        &mut self,
3176        module: &Module,
3177        expr: Handle<crate::Expression>,
3178        func_ctx: &back::FunctionCtx<'_>,
3179    ) -> BackendResult {
3180        use crate::Expression;
3181
3182        // Handle the special semantics of vertex_index/instance_index
3183        let ff_input = if self.options.special_constants_binding.is_some() {
3184            func_ctx.is_fixed_function_input(expr, module)
3185        } else {
3186            None
3187        };
3188        let closing_bracket = match ff_input {
3189            Some(crate::BuiltIn::VertexIndex) => {
3190                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3191                ")"
3192            }
3193            Some(crate::BuiltIn::InstanceIndex) => {
3194                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3195                ")"
3196            }
3197            Some(crate::BuiltIn::NumWorkGroups) => {
3198                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
3199                // in compute shaders the special constants contain the number
3200                // of workgroups, which we are using here.
3201                write!(
3202                    self.out,
3203                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3204                )?;
3205                return Ok(());
3206            }
3207            _ => "",
3208        };
3209
3210        if let Some(name) = self.named_expressions.get(&expr) {
3211            write!(self.out, "{name}{closing_bracket}")?;
3212            return Ok(());
3213        }
3214
3215        let expression = &func_ctx.expressions[expr];
3216
3217        match *expression {
3218            Expression::Literal(_)
3219            | Expression::Constant(_)
3220            | Expression::ZeroValue(_)
3221            | Expression::Compose { .. }
3222            | Expression::Splat { .. } => {
3223                self.write_possibly_const_expression(
3224                    module,
3225                    expr,
3226                    func_ctx.expressions,
3227                    |writer, expr| writer.write_expr(module, expr, func_ctx),
3228                )?;
3229            }
3230            Expression::Override(_) => return Err(Error::Override),
3231            // Avoid undefined behaviour for addition, subtraction, and
3232            // multiplication of signed integers by casting operands to
3233            // unsigned, performing the operation, then casting the result back
3234            // to signed.
3235            // TODO(#7109): This relies on the asint()/asuint() functions which only work
3236            // for 32-bit types, so we must find another solution for different bit widths.
3237            Expression::Binary {
3238                op:
3239                    op @ crate::BinaryOperator::Add
3240                    | op @ crate::BinaryOperator::Subtract
3241                    | op @ crate::BinaryOperator::Multiply,
3242                left,
3243                right,
3244            } if matches!(
3245                func_ctx.resolve_type(expr, &module.types).scalar(),
3246                Some(Scalar::I32)
3247            ) =>
3248            {
3249                write!(self.out, "asint(asuint(",)?;
3250                self.write_expr(module, left, func_ctx)?;
3251                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3252                self.write_expr(module, right, func_ctx)?;
3253                write!(self.out, "))")?;
3254            }
3255            // All of the multiplication can be expressed as `mul`,
3256            // except vector * vector, which needs to use the "*" operator.
3257            Expression::Binary {
3258                op: crate::BinaryOperator::Multiply,
3259                left,
3260                right,
3261            } if func_ctx.resolve_type(left, &module.types).is_matrix()
3262                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3263            {
3264                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
3265                write!(self.out, "mul(")?;
3266                self.write_expr(module, right, func_ctx)?;
3267                write!(self.out, ", ")?;
3268                self.write_expr(module, left, func_ctx)?;
3269                write!(self.out, ")")?;
3270            }
3271
3272            // WGSL says that floating-point division by zero should return
3273            // infinity. Microsoft's Direct3D 11 functional specification
3274            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
3275            // says:
3276            //
3277            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3278            //
3279            // which is what we want. The DXIL specification for the FDiv
3280            // instruction corroborates this:
3281            //
3282            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3283            Expression::Binary {
3284                op: crate::BinaryOperator::Divide,
3285                left,
3286                right,
3287            } if matches!(
3288                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3289                Some(ScalarKind::Sint | ScalarKind::Uint)
3290            ) =>
3291            {
3292                write!(self.out, "{DIV_FUNCTION}(")?;
3293                self.write_expr(module, left, func_ctx)?;
3294                write!(self.out, ", ")?;
3295                self.write_expr(module, right, func_ctx)?;
3296                write!(self.out, ")")?;
3297            }
3298
3299            Expression::Binary {
3300                op: crate::BinaryOperator::Modulo,
3301                left,
3302                right,
3303            } if matches!(
3304                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3305                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3306            ) =>
3307            {
3308                write!(self.out, "{MOD_FUNCTION}(")?;
3309                self.write_expr(module, left, func_ctx)?;
3310                write!(self.out, ", ")?;
3311                self.write_expr(module, right, func_ctx)?;
3312                write!(self.out, ")")?;
3313            }
3314
3315            Expression::Binary { op, left, right } => {
3316                write!(self.out, "(")?;
3317                self.write_expr(module, left, func_ctx)?;
3318                write!(self.out, " {} ", back::binary_operation_str(op))?;
3319                self.write_expr(module, right, func_ctx)?;
3320                write!(self.out, ")")?;
3321            }
3322            Expression::Access { base, index } => {
3323                if let Some(crate::AddressSpace::Storage { .. }) =
3324                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3325                {
3326                    // do nothing, the chain is written on `Load`/`Store`
3327                } else {
3328                    // We use the function __get_col_of_matCx2 here in cases
3329                    // where `base`s type resolves to a matCx2 and is part of a
3330                    // struct member with type of (possibly nested) array of matCx2's.
3331                    //
3332                    // Note that this only works for `Load`s and we handle
3333                    // `Store`s differently in `Statement::Store`.
3334                    if let Some(MatrixType {
3335                        columns,
3336                        rows: crate::VectorSize::Bi,
3337                        width: 4,
3338                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3339                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3340                    {
3341                        write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
3342                        self.write_expr(module, base, func_ctx)?;
3343                        write!(self.out, ", ")?;
3344                        self.write_expr(module, index, func_ctx)?;
3345                        write!(self.out, ")")?;
3346                        return Ok(());
3347                    }
3348
3349                    let resolved = func_ctx.resolve_type(base, &module.types);
3350
3351                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3352                        TypeInner::BindingArray { .. } => {
3353                            let uniformity = &func_ctx.info[index].uniformity;
3354
3355                            (true, uniformity.non_uniform_result.is_some())
3356                        }
3357                        _ => (false, false),
3358                    };
3359
3360                    self.write_expr(module, base, func_ctx)?;
3361
3362                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3363                        module, func_ctx, base, resolved,
3364                    );
3365
3366                    if let Some(ref info) = array_sampler_info {
3367                        write!(self.out, "{}[", info.sampler_heap_name)?;
3368                    } else {
3369                        write!(self.out, "[")?;
3370                    }
3371
3372                    let needs_bound_check = self.options.restrict_indexing
3373                        && !indexing_binding_array
3374                        && match resolved.pointer_space() {
3375                            Some(
3376                                crate::AddressSpace::Function
3377                                | crate::AddressSpace::Private
3378                                | crate::AddressSpace::WorkGroup
3379                                | crate::AddressSpace::Immediate
3380                                | crate::AddressSpace::TaskPayload
3381                                | crate::AddressSpace::RayPayload
3382                                | crate::AddressSpace::IncomingRayPayload,
3383                            )
3384                            | None => true,
3385                            Some(crate::AddressSpace::Uniform) => {
3386                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3387                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3388                                let bind_target = self
3389                                    .options
3390                                    .resolve_resource_binding(
3391                                        module.global_variables[var_handle]
3392                                            .binding
3393                                            .as_ref()
3394                                            .unwrap(),
3395                                    )
3396                                    .unwrap();
3397                                bind_target.restrict_indexing
3398                            }
3399                            Some(
3400                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3401                            ) => unreachable!(),
3402                        };
3403                    // Decide whether this index needs to be clamped to fall within range.
3404                    let restriction_needed = if needs_bound_check {
3405                        index::access_needs_check(
3406                            base,
3407                            index::GuardedIndex::Expression(index),
3408                            module,
3409                            func_ctx.expressions,
3410                            func_ctx.info,
3411                        )
3412                    } else {
3413                        None
3414                    };
3415                    if let Some(limit) = restriction_needed {
3416                        write!(self.out, "min(uint(")?;
3417                        self.write_expr(module, index, func_ctx)?;
3418                        write!(self.out, "), ")?;
3419                        match limit {
3420                            index::IndexableLength::Known(limit) => {
3421                                write!(self.out, "{}u", limit - 1)?;
3422                            }
3423                            index::IndexableLength::Dynamic => unreachable!(),
3424                        }
3425                        write!(self.out, ")")?;
3426                    } else {
3427                        if non_uniform_qualifier {
3428                            write!(self.out, "NonUniformResourceIndex(")?;
3429                        }
3430                        if let Some(ref info) = array_sampler_info {
3431                            write!(
3432                                self.out,
3433                                "{}[{} + ",
3434                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3435                            )?;
3436                        }
3437                        self.write_expr(module, index, func_ctx)?;
3438                        if array_sampler_info.is_some() {
3439                            write!(self.out, "]")?;
3440                        }
3441                        if non_uniform_qualifier {
3442                            write!(self.out, ")")?;
3443                        }
3444                    }
3445
3446                    write!(self.out, "]")?;
3447                }
3448            }
3449            Expression::AccessIndex { base, index } => {
3450                if let Some(crate::AddressSpace::Storage { .. }) =
3451                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3452                {
3453                    // do nothing, the chain is written on `Load`/`Store`
3454                } else {
3455                    // See if we need to write the matrix column access in a
3456                    // special way since the type of `base` is our special
3457                    // __matCx2 struct.
3458                    if let Some(MatrixType {
3459                        rows: crate::VectorSize::Bi,
3460                        width: 4,
3461                        ..
3462                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3463                        .or_else(|| get_global_uniform_matrix(module, base, func_ctx))
3464                    {
3465                        self.write_expr(module, base, func_ctx)?;
3466                        write!(self.out, "._{index}")?;
3467                        return Ok(());
3468                    }
3469
3470                    let base_ty_res = &func_ctx.info[base].ty;
3471                    let mut resolved = base_ty_res.inner_with(&module.types);
3472                    let base_ty_handle = match *resolved {
3473                        TypeInner::Pointer { base, .. } => {
3474                            resolved = &module.types[base].inner;
3475                            Some(base)
3476                        }
3477                        _ => base_ty_res.handle(),
3478                    };
3479
3480                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3481                    // See the module-level block comment in mod.rs for details.
3482                    //
3483                    // We handle matrix reconstruction here for Loads.
3484                    // Stores are handled directly by `Statement::Store`.
3485                    if let TypeInner::Struct { ref members, .. } = *resolved {
3486                        let member = &members[index as usize];
3487
3488                        match module.types[member.ty].inner {
3489                            TypeInner::Matrix {
3490                                rows: crate::VectorSize::Bi,
3491                                ..
3492                            } if member.binding.is_none() => {
3493                                let ty = base_ty_handle.unwrap();
3494                                self.write_wrapped_struct_matrix_get_function_name(
3495                                    WrappedStructMatrixAccess { ty, index },
3496                                )?;
3497                                write!(self.out, "(")?;
3498                                self.write_expr(module, base, func_ctx)?;
3499                                write!(self.out, ")")?;
3500                                return Ok(());
3501                            }
3502                            _ => {}
3503                        }
3504                    }
3505
3506                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3507                        module, func_ctx, base, resolved,
3508                    );
3509
3510                    if let Some(ref info) = array_sampler_info {
3511                        write!(
3512                            self.out,
3513                            "{}[{}",
3514                            info.sampler_heap_name, info.sampler_index_buffer_name
3515                        )?;
3516                    }
3517
3518                    self.write_expr(module, base, func_ctx)?;
3519
3520                    match *resolved {
3521                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3522                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3523                        // it and generates completely nonsensical DXBC.
3524                        //
3525                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3526                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3527                            // Write vector access as a swizzle
3528                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3529                        }
3530                        TypeInner::Matrix { .. }
3531                        | TypeInner::Array { .. }
3532                        | TypeInner::BindingArray { .. } => {
3533                            if let Some(ref info) = array_sampler_info {
3534                                write!(
3535                                    self.out,
3536                                    "[{} + {index}]",
3537                                    info.binding_array_base_index_name
3538                                )?;
3539                            } else {
3540                                write!(self.out, "[{index}]")?;
3541                            }
3542                        }
3543                        TypeInner::Struct { .. } => {
3544                            // This will never panic in case the type is a `Struct`, this is not true
3545                            // for other types so we can only check while inside this match arm
3546                            let ty = base_ty_handle.unwrap();
3547
3548                            write!(
3549                                self.out,
3550                                ".{}",
3551                                &self.names[&NameKey::StructMember(ty, index)]
3552                            )?
3553                        }
3554                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3555                    }
3556
3557                    if array_sampler_info.is_some() {
3558                        write!(self.out, "]")?;
3559                    }
3560                }
3561            }
3562            Expression::FunctionArgument(pos) => {
3563                let ty = func_ctx.resolve_type(expr, &module.types);
3564
3565                // We know that any external texture function argument has been expanded into
3566                // separate consecutive arguments for each plane and the parameters buffer. And we
3567                // also know that external textures can only ever be used as an argument to another
3568                // function. Therefore we can simply emit each of the expanded arguments in a
3569                // consecutive comma-separated list.
3570                if let TypeInner::Image {
3571                    class: crate::ImageClass::External,
3572                    ..
3573                } = *ty
3574                {
3575                    let plane_names = [0, 1, 2].map(|i| {
3576                        &self.names[&func_ctx
3577                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3578                    });
3579                    let params_name = &self.names[&func_ctx
3580                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3581                    write!(
3582                        self.out,
3583                        "{}, {}, {}, {}",
3584                        plane_names[0], plane_names[1], plane_names[2], params_name
3585                    )?;
3586                } else {
3587                    let key = func_ctx.argument_key(pos);
3588                    let name = &self.names[&key];
3589                    write!(self.out, "{name}")?;
3590                }
3591            }
3592            Expression::ImageSample {
3593                coordinate,
3594                image,
3595                sampler,
3596                clamp_to_edge: true,
3597                gather: None,
3598                array_index: None,
3599                offset: None,
3600                level: crate::SampleLevel::Zero,
3601                depth_ref: None,
3602            } => {
3603                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3604                self.write_expr(module, image, func_ctx)?;
3605                write!(self.out, ", ")?;
3606                self.write_expr(module, sampler, func_ctx)?;
3607                write!(self.out, ", ")?;
3608                self.write_expr(module, coordinate, func_ctx)?;
3609                write!(self.out, ")")?;
3610            }
3611            Expression::ImageSample {
3612                image,
3613                sampler,
3614                gather,
3615                coordinate,
3616                array_index,
3617                offset,
3618                level,
3619                depth_ref,
3620                clamp_to_edge,
3621            } => {
3622                if clamp_to_edge {
3623                    return Err(Error::Custom(
3624                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3625                    ));
3626                }
3627
3628                use crate::SampleLevel as Sl;
3629                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3630
3631                let (base_str, component_str) = match gather {
3632                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3633                    None => ("Sample", ""),
3634                };
3635                let cmp_str = match depth_ref {
3636                    Some(_) => "Cmp",
3637                    None => "",
3638                };
3639                let level_str = match level {
3640                    Sl::Zero if gather.is_none() => "LevelZero",
3641                    Sl::Auto | Sl::Zero => "",
3642                    Sl::Exact(_) => "Level",
3643                    Sl::Bias(_) => "Bias",
3644                    Sl::Gradient { .. } => "Grad",
3645                };
3646
3647                self.write_expr(module, image, func_ctx)?;
3648                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3649                self.write_expr(module, sampler, func_ctx)?;
3650                write!(self.out, ", ")?;
3651                self.write_texture_coordinates(
3652                    "float",
3653                    coordinate,
3654                    array_index,
3655                    None,
3656                    module,
3657                    func_ctx,
3658                )?;
3659
3660                if let Some(depth_ref) = depth_ref {
3661                    write!(self.out, ", ")?;
3662                    self.write_expr(module, depth_ref, func_ctx)?;
3663                }
3664
3665                match level {
3666                    Sl::Auto | Sl::Zero => {}
3667                    Sl::Exact(expr) => {
3668                        write!(self.out, ", ")?;
3669                        self.write_expr(module, expr, func_ctx)?;
3670                    }
3671                    Sl::Bias(expr) => {
3672                        write!(self.out, ", ")?;
3673                        self.write_expr(module, expr, func_ctx)?;
3674                    }
3675                    Sl::Gradient { x, y } => {
3676                        write!(self.out, ", ")?;
3677                        self.write_expr(module, x, func_ctx)?;
3678                        write!(self.out, ", ")?;
3679                        self.write_expr(module, y, func_ctx)?;
3680                    }
3681                }
3682
3683                if let Some(offset) = offset {
3684                    write!(self.out, ", ")?;
3685                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3686                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3687                    write!(self.out, ")")?;
3688                }
3689
3690                write!(self.out, ")")?;
3691            }
3692            Expression::ImageQuery { image, query } => {
3693                // use wrapped image query function
3694                if let TypeInner::Image {
3695                    dim,
3696                    arrayed,
3697                    class,
3698                } = *func_ctx.resolve_type(image, &module.types)
3699                {
3700                    let wrapped_image_query = WrappedImageQuery {
3701                        dim,
3702                        arrayed,
3703                        class,
3704                        query: query.into(),
3705                    };
3706
3707                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3708                    write!(self.out, "(")?;
3709                    // Image always first param
3710                    self.write_expr(module, image, func_ctx)?;
3711                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3712                        write!(self.out, ", ")?;
3713                        self.write_expr(module, level, func_ctx)?;
3714                    }
3715                    write!(self.out, ")")?;
3716                }
3717            }
3718            Expression::ImageLoad {
3719                image,
3720                coordinate,
3721                array_index,
3722                sample,
3723                level,
3724            } => self.write_image_load(
3725                &module,
3726                expr,
3727                func_ctx,
3728                image,
3729                coordinate,
3730                array_index,
3731                sample,
3732                level,
3733            )?,
3734            Expression::GlobalVariable(handle) => {
3735                let global_variable = &module.global_variables[handle];
3736                let ty = &module.types[global_variable.ty].inner;
3737
3738                // In the case of binding arrays of samplers, we need to not write anything
                // as we are in the wrong position to fully write the expression.
3740                //
3741                // The entire writing is done by AccessIndex.
3742                let is_binding_array_of_samplers = match *ty {
3743                    TypeInner::BindingArray { base, .. } => {
3744                        let base_ty = &module.types[base].inner;
3745                        matches!(*base_ty, TypeInner::Sampler { .. })
3746                    }
3747                    _ => false,
3748                };
3749
3750                let is_storage_space =
3751                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3752
3753                // Our external texture global variable has been expanded into multiple
3754                // global variables, one for each plane and the parameters buffer.
3755                // External textures can only ever be used as arguments to a function
3756                // call, and we know that an external texture argument to any function
3757                // will have been expanded to separate consecutive arguments for each
3758                // plane and the parameters buffer. Therefore we can simply emit each of
3759                // the expanded global variables in a consecutive comma-separated list.
3760                if let TypeInner::Image {
3761                    class: crate::ImageClass::External,
3762                    ..
3763                } = *ty
3764                {
3765                    let plane_names = [0, 1, 2].map(|i| {
3766                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3767                            handle,
3768                            ExternalTextureNameKey::Plane(i),
3769                        )]
3770                    });
3771                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3772                        handle,
3773                        ExternalTextureNameKey::Params,
3774                    )];
3775                    write!(
3776                        self.out,
3777                        "{}, {}, {}, {}",
3778                        plane_names[0], plane_names[1], plane_names[2], params_name
3779                    )?;
3780                } else if !is_binding_array_of_samplers && !is_storage_space {
3781                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3782                    write!(self.out, "{name}")?;
3783                }
3784            }
3785            Expression::LocalVariable(handle) => {
3786                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3787            }
3788            Expression::Load { pointer } => {
3789                match func_ctx
3790                    .resolve_type(pointer, &module.types)
3791                    .pointer_space()
3792                {
3793                    Some(crate::AddressSpace::Storage { .. }) => {
3794                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3795                        let result_ty = func_ctx.info[expr].ty.clone();
3796                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3797                    }
3798                    _ => {
3799                        let mut close_paren = false;
3800
3801                        // We cast the value loaded to a native HLSL floatCx2
3802                        // in cases where it is of type:
3803                        //  - __matCx2 or
3804                        //  - a (possibly nested) array of __matCx2's
3805                        if let Some(MatrixType {
3806                            rows: crate::VectorSize::Bi,
3807                            width: 4,
3808                            ..
3809                        }) = get_inner_matrix_of_struct_array_member(
3810                            module, pointer, func_ctx, false,
3811                        )
3812                        .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
3813                        {
3814                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3815                            let ptr_tr = resolved.pointer_base_type();
3816                            if let Some(ptr_ty) =
3817                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3818                            {
3819                                resolved = ptr_ty;
3820                            }
3821
3822                            write!(self.out, "((")?;
3823                            if let TypeInner::Array { base, size, .. } = *resolved {
3824                                self.write_type(module, base)?;
3825                                self.write_array_size(module, base, size)?;
3826                            } else {
3827                                self.write_value_type(module, resolved)?;
3828                            }
3829                            write!(self.out, ")")?;
3830                            close_paren = true;
3831                        }
3832
3833                        self.write_expr(module, pointer, func_ctx)?;
3834
3835                        if close_paren {
3836                            write!(self.out, ")")?;
3837                        }
3838                    }
3839                }
3840            }
3841            Expression::Unary { op, expr } => {
3842                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3843                let op_str = match op {
3844                    crate::UnaryOperator::Negate => {
3845                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3846                            Some(Scalar::I32) => NEG_FUNCTION,
3847                            _ => "-",
3848                        }
3849                    }
3850                    crate::UnaryOperator::LogicalNot => "!",
3851                    crate::UnaryOperator::BitwiseNot => "~",
3852                };
3853                write!(self.out, "{op_str}(")?;
3854                self.write_expr(module, expr, func_ctx)?;
3855                write!(self.out, ")")?;
3856            }
3857            Expression::As {
3858                expr,
3859                kind,
3860                convert,
3861            } => {
3862                let inner = func_ctx.resolve_type(expr, &module.types);
3863                if inner.scalar_kind() == Some(ScalarKind::Float)
3864                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3865                    && convert.is_some()
3866                {
3867                    // Use helper functions for float to int casts in order to
3868                    // avoid undefined behaviour when value is out of range for
3869                    // the target type.
3870                    let fun_name = match (kind, convert) {
3871                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3872                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3873                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3874                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3875                        _ => unreachable!(),
3876                    };
3877                    write!(self.out, "{fun_name}(")?;
3878                    self.write_expr(module, expr, func_ctx)?;
3879                    write!(self.out, ")")?;
3880                } else {
3881                    let close_paren = match convert {
3882                        Some(dst_width) => {
3883                            let scalar = Scalar {
3884                                kind,
3885                                width: dst_width,
3886                            };
3887                            match *inner {
3888                                TypeInner::Vector { size, .. } => {
3889                                    write!(
3890                                        self.out,
3891                                        "{}{}(",
3892                                        scalar.to_hlsl_str()?,
3893                                        common::vector_size_str(size)
3894                                    )?;
3895                                }
3896                                TypeInner::Scalar(_) => {
3897                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3898                                }
3899                                TypeInner::Matrix { columns, rows, .. } => {
3900                                    write!(
3901                                        self.out,
3902                                        "{}{}x{}(",
3903                                        scalar.to_hlsl_str()?,
3904                                        common::vector_size_str(columns),
3905                                        common::vector_size_str(rows)
3906                                    )?;
3907                                }
3908                                _ => {
3909                                    return Err(Error::Unimplemented(format!(
3910                                        "write_expr expression::as {inner:?}"
3911                                    )));
3912                                }
3913                            };
3914                            true
3915                        }
3916                        None => {
3917                            if inner.scalar_width() == Some(8) {
3918                                false
3919                            } else {
3920                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3921                                true
3922                            }
3923                        }
3924                    };
3925                    self.write_expr(module, expr, func_ctx)?;
3926                    if close_paren {
3927                        write!(self.out, ")")?;
3928                    }
3929                }
3930            }
3931            Expression::Math {
3932                fun,
3933                arg,
3934                arg1,
3935                arg2,
3936                arg3,
3937            } => {
3938                use crate::MathFunction as Mf;
3939
3940                enum Function {
3941                    Asincosh { is_sin: bool },
3942                    Atanh,
3943                    Pack2x16float,
3944                    Pack2x16snorm,
3945                    Pack2x16unorm,
3946                    Pack4x8snorm,
3947                    Pack4x8unorm,
3948                    Pack4xI8,
3949                    Pack4xU8,
3950                    Pack4xI8Clamp,
3951                    Pack4xU8Clamp,
3952                    Unpack2x16float,
3953                    Unpack2x16snorm,
3954                    Unpack2x16unorm,
3955                    Unpack4x8snorm,
3956                    Unpack4x8unorm,
3957                    Unpack4xI8,
3958                    Unpack4xU8,
3959                    Dot4I8Packed,
3960                    Dot4U8Packed,
3961                    QuantizeToF16,
3962                    Regular(&'static str),
3963                    MissingIntOverload(&'static str),
3964                    MissingIntReturnType(&'static str),
3965                    CountTrailingZeros,
3966                    CountLeadingZeros,
3967                }
3968
3969                let fun = match fun {
3970                    // comparison
3971                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3972                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3973                        _ => Function::Regular("abs"),
3974                    },
3975                    Mf::Min => Function::Regular("min"),
3976                    Mf::Max => Function::Regular("max"),
3977                    Mf::Clamp => Function::Regular("clamp"),
3978                    Mf::Saturate => Function::Regular("saturate"),
3979                    // trigonometry
3980                    Mf::Cos => Function::Regular("cos"),
3981                    Mf::Cosh => Function::Regular("cosh"),
3982                    Mf::Sin => Function::Regular("sin"),
3983                    Mf::Sinh => Function::Regular("sinh"),
3984                    Mf::Tan => Function::Regular("tan"),
3985                    Mf::Tanh => Function::Regular("tanh"),
3986                    Mf::Acos => Function::Regular("acos"),
3987                    Mf::Asin => Function::Regular("asin"),
3988                    Mf::Atan => Function::Regular("atan"),
3989                    Mf::Atan2 => Function::Regular("atan2"),
3990                    Mf::Asinh => Function::Asincosh { is_sin: true },
3991                    Mf::Acosh => Function::Asincosh { is_sin: false },
3992                    Mf::Atanh => Function::Atanh,
3993                    Mf::Radians => Function::Regular("radians"),
3994                    Mf::Degrees => Function::Regular("degrees"),
3995                    // decomposition
3996                    Mf::Ceil => Function::Regular("ceil"),
3997                    Mf::Floor => Function::Regular("floor"),
3998                    Mf::Round => Function::Regular("round"),
3999                    Mf::Fract => Function::Regular("frac"),
4000                    Mf::Trunc => Function::Regular("trunc"),
4001                    Mf::Modf => Function::Regular(MODF_FUNCTION),
4002                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
4003                    Mf::Ldexp => Function::Regular("ldexp"),
4004                    // exponent
4005                    Mf::Exp => Function::Regular("exp"),
4006                    Mf::Exp2 => Function::Regular("exp2"),
4007                    Mf::Log => Function::Regular("log"),
4008                    Mf::Log2 => Function::Regular("log2"),
4009                    Mf::Pow => Function::Regular("pow"),
4010                    // geometry
4011                    Mf::Dot => Function::Regular("dot"),
4012                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
4013                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
4014                    //Mf::Outer => ,
4015                    Mf::Cross => Function::Regular("cross"),
4016                    Mf::Distance => Function::Regular("distance"),
4017                    Mf::Length => Function::Regular("length"),
4018                    Mf::Normalize => Function::Regular("normalize"),
4019                    Mf::FaceForward => Function::Regular("faceforward"),
4020                    Mf::Reflect => Function::Regular("reflect"),
4021                    Mf::Refract => Function::Regular("refract"),
4022                    // computational
4023                    Mf::Sign => Function::Regular("sign"),
4024                    Mf::Fma => Function::Regular("mad"),
4025                    Mf::Mix => Function::Regular("lerp"),
4026                    Mf::Step => Function::Regular("step"),
4027                    Mf::SmoothStep => Function::Regular("smoothstep"),
4028                    Mf::Sqrt => Function::Regular("sqrt"),
4029                    Mf::InverseSqrt => Function::Regular("rsqrt"),
4030                    //Mf::Inverse =>,
4031                    Mf::Transpose => Function::Regular("transpose"),
4032                    Mf::Determinant => Function::Regular("determinant"),
4033                    Mf::QuantizeToF16 => Function::QuantizeToF16,
4034                    // bits
4035                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
4036                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
4037                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
4038                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
4039                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
4040                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
4041                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
4042                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
4043                    // Data Packing
4044                    Mf::Pack2x16float => Function::Pack2x16float,
4045                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
4046                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
4047                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
4048                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
4049                    Mf::Pack4xI8 => Function::Pack4xI8,
4050                    Mf::Pack4xU8 => Function::Pack4xU8,
4051                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4052                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4053                    // Data Unpacking
4054                    Mf::Unpack2x16float => Function::Unpack2x16float,
4055                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4056                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4057                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4058                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4059                    Mf::Unpack4xI8 => Function::Unpack4xI8,
4060                    Mf::Unpack4xU8 => Function::Unpack4xU8,
4061                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4062                };
4063
4064                match fun {
4065                    Function::Asincosh { is_sin } => {
4066                        write!(self.out, "log(")?;
4067                        self.write_expr(module, arg, func_ctx)?;
4068                        write!(self.out, " + sqrt(")?;
4069                        self.write_expr(module, arg, func_ctx)?;
4070                        write!(self.out, " * ")?;
4071                        self.write_expr(module, arg, func_ctx)?;
4072                        match is_sin {
4073                            true => write!(self.out, " + 1.0))")?,
4074                            false => write!(self.out, " - 1.0))")?,
4075                        }
4076                    }
4077                    Function::Atanh => {
4078                        write!(self.out, "0.5 * log((1.0 + ")?;
4079                        self.write_expr(module, arg, func_ctx)?;
4080                        write!(self.out, ") / (1.0 - ")?;
4081                        self.write_expr(module, arg, func_ctx)?;
4082                        write!(self.out, "))")?;
4083                    }
4084                    Function::Pack2x16float => {
4085                        write!(self.out, "(f32tof16(")?;
4086                        self.write_expr(module, arg, func_ctx)?;
4087                        write!(self.out, "[0]) | f32tof16(")?;
4088                        self.write_expr(module, arg, func_ctx)?;
4089                        write!(self.out, "[1]) << 16)")?;
4090                    }
4091                    Function::Pack2x16snorm => {
4092                        let scale = 32767;
4093
4094                        write!(self.out, "uint((int(round(clamp(")?;
4095                        self.write_expr(module, arg, func_ctx)?;
4096                        write!(
4097                            self.out,
4098                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4099                        )?;
4100                        self.write_expr(module, arg, func_ctx)?;
4101                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4102                    }
4103                    Function::Pack2x16unorm => {
4104                        let scale = 65535;
4105
4106                        write!(self.out, "(uint(round(clamp(")?;
4107                        self.write_expr(module, arg, func_ctx)?;
4108                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4109                        self.write_expr(module, arg, func_ctx)?;
4110                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4111                    }
4112                    Function::Pack4x8snorm => {
4113                        let scale = 127;
4114
4115                        write!(self.out, "uint((int(round(clamp(")?;
4116                        self.write_expr(module, arg, func_ctx)?;
4117                        write!(
4118                            self.out,
4119                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4120                        )?;
4121                        self.write_expr(module, arg, func_ctx)?;
4122                        write!(
4123                            self.out,
4124                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4125                        )?;
4126                        self.write_expr(module, arg, func_ctx)?;
4127                        write!(
4128                            self.out,
4129                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4130                        )?;
4131                        self.write_expr(module, arg, func_ctx)?;
4132                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4133                    }
4134                    Function::Pack4x8unorm => {
4135                        let scale = 255;
4136
4137                        write!(self.out, "(uint(round(clamp(")?;
4138                        self.write_expr(module, arg, func_ctx)?;
4139                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4140                        self.write_expr(module, arg, func_ctx)?;
4141                        write!(
4142                            self.out,
4143                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4144                        )?;
4145                        self.write_expr(module, arg, func_ctx)?;
4146                        write!(
4147                            self.out,
4148                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4149                        )?;
4150                        self.write_expr(module, arg, func_ctx)?;
4151                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4152                    }
4153                    fun @ (Function::Pack4xI8
4154                    | Function::Pack4xU8
4155                    | Function::Pack4xI8Clamp
4156                    | Function::Pack4xU8Clamp) => {
4157                        let was_signed =
4158                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4159                        let clamp_bounds = match fun {
4160                            Function::Pack4xI8Clamp => Some(("-128", "127")),
4161                            Function::Pack4xU8Clamp => Some(("0", "255")),
4162                            _ => None,
4163                        };
4164                        if was_signed {
4165                            write!(self.out, "uint(")?;
4166                        }
4167                        let write_arg = |this: &mut Self| -> BackendResult {
4168                            if let Some((min, max)) = clamp_bounds {
4169                                write!(this.out, "clamp(")?;
4170                                this.write_expr(module, arg, func_ctx)?;
4171                                write!(this.out, ", {min}, {max})")?;
4172                            } else {
4173                                this.write_expr(module, arg, func_ctx)?;
4174                            }
4175                            Ok(())
4176                        };
4177                        write!(self.out, "(")?;
4178                        write_arg(self)?;
4179                        write!(self.out, "[0] & 0xFF) | ((")?;
4180                        write_arg(self)?;
4181                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4182                        write_arg(self)?;
4183                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4184                        write_arg(self)?;
4185                        write!(self.out, "[3] & 0xFF) << 24)")?;
4186                        if was_signed {
4187                            write!(self.out, ")")?;
4188                        }
4189                    }
4190
4191                    Function::Unpack2x16float => {
4192                        write!(self.out, "float2(f16tof32(")?;
4193                        self.write_expr(module, arg, func_ctx)?;
4194                        write!(self.out, "), f16tof32((")?;
4195                        self.write_expr(module, arg, func_ctx)?;
4196                        write!(self.out, ") >> 16))")?;
4197                    }
4198                    Function::Unpack2x16snorm => {
4199                        let scale = 32767;
4200
4201                        write!(self.out, "(float2(int2(")?;
4202                        self.write_expr(module, arg, func_ctx)?;
4203                        write!(self.out, " << 16, ")?;
4204                        self.write_expr(module, arg, func_ctx)?;
4205                        write!(self.out, ") >> 16) / {scale}.0)")?;
4206                    }
4207                    Function::Unpack2x16unorm => {
4208                        let scale = 65535;
4209
4210                        write!(self.out, "(float2(")?;
4211                        self.write_expr(module, arg, func_ctx)?;
4212                        write!(self.out, " & 0xFFFF, ")?;
4213                        self.write_expr(module, arg, func_ctx)?;
4214                        write!(self.out, " >> 16) / {scale}.0)")?;
4215                    }
4216                    Function::Unpack4x8snorm => {
4217                        let scale = 127;
4218
4219                        write!(self.out, "(float4(int4(")?;
4220                        self.write_expr(module, arg, func_ctx)?;
4221                        write!(self.out, " << 24, ")?;
4222                        self.write_expr(module, arg, func_ctx)?;
4223                        write!(self.out, " << 16, ")?;
4224                        self.write_expr(module, arg, func_ctx)?;
4225                        write!(self.out, " << 8, ")?;
4226                        self.write_expr(module, arg, func_ctx)?;
4227                        write!(self.out, ") >> 24) / {scale}.0)")?;
4228                    }
4229                    Function::Unpack4x8unorm => {
4230                        let scale = 255;
4231
4232                        write!(self.out, "(float4(")?;
4233                        self.write_expr(module, arg, func_ctx)?;
4234                        write!(self.out, " & 0xFF, ")?;
4235                        self.write_expr(module, arg, func_ctx)?;
4236                        write!(self.out, " >> 8 & 0xFF, ")?;
4237                        self.write_expr(module, arg, func_ctx)?;
4238                        write!(self.out, " >> 16 & 0xFF, ")?;
4239                        self.write_expr(module, arg, func_ctx)?;
4240                        write!(self.out, " >> 24) / {scale}.0)")?;
4241                    }
4242                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4243                        write!(self.out, "(")?;
4244                        if matches!(fun, Function::Unpack4xU8) {
4245                            write!(self.out, "u")?;
4246                        }
4247                        write!(self.out, "int4(")?;
4248                        self.write_expr(module, arg, func_ctx)?;
4249                        write!(self.out, ", ")?;
4250                        self.write_expr(module, arg, func_ctx)?;
4251                        write!(self.out, " >> 8, ")?;
4252                        self.write_expr(module, arg, func_ctx)?;
4253                        write!(self.out, " >> 16, ")?;
4254                        self.write_expr(module, arg, func_ctx)?;
4255                        write!(self.out, " >> 24) << 24 >> 24)")?;
4256                    }
4257                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4258                        let arg1 = arg1.unwrap();
4259
4260                        if self.options.shader_model >= ShaderModel::V6_4 {
4261                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
4262                            let function_name = match fun {
4263                                Function::Dot4I8Packed => "dot4add_i8packed",
4264                                Function::Dot4U8Packed => "dot4add_u8packed",
4265                                _ => unreachable!(),
4266                            };
4267                            write!(self.out, "{function_name}(")?;
4268                            self.write_expr(module, arg, func_ctx)?;
4269                            write!(self.out, ", ")?;
4270                            self.write_expr(module, arg1, func_ctx)?;
4271                            write!(self.out, ", 0)")?;
4272                        } else {
4273                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
4274                            write!(self.out, "dot(")?;
4275
4276                            if matches!(fun, Function::Dot4U8Packed) {
4277                                write!(self.out, "u")?;
4278                            }
4279                            write!(self.out, "int4(")?;
4280                            self.write_expr(module, arg, func_ctx)?;
4281                            write!(self.out, ", ")?;
4282                            self.write_expr(module, arg, func_ctx)?;
4283                            write!(self.out, " >> 8, ")?;
4284                            self.write_expr(module, arg, func_ctx)?;
4285                            write!(self.out, " >> 16, ")?;
4286                            self.write_expr(module, arg, func_ctx)?;
4287                            write!(self.out, " >> 24) << 24 >> 24, ")?;
4288
4289                            if matches!(fun, Function::Dot4U8Packed) {
4290                                write!(self.out, "u")?;
4291                            }
4292                            write!(self.out, "int4(")?;
4293                            self.write_expr(module, arg1, func_ctx)?;
4294                            write!(self.out, ", ")?;
4295                            self.write_expr(module, arg1, func_ctx)?;
4296                            write!(self.out, " >> 8, ")?;
4297                            self.write_expr(module, arg1, func_ctx)?;
4298                            write!(self.out, " >> 16, ")?;
4299                            self.write_expr(module, arg1, func_ctx)?;
4300                            write!(self.out, " >> 24) << 24 >> 24)")?;
4301                        }
4302                    }
4303                    Function::QuantizeToF16 => {
4304                        write!(self.out, "f16tof32(f32tof16(")?;
4305                        self.write_expr(module, arg, func_ctx)?;
4306                        write!(self.out, "))")?;
4307                    }
4308                    Function::Regular(fun_name) => {
4309                        write!(self.out, "{fun_name}(")?;
4310                        self.write_expr(module, arg, func_ctx)?;
4311                        if let Some(arg) = arg1 {
4312                            write!(self.out, ", ")?;
4313                            self.write_expr(module, arg, func_ctx)?;
4314                        }
4315                        if let Some(arg) = arg2 {
4316                            write!(self.out, ", ")?;
4317                            self.write_expr(module, arg, func_ctx)?;
4318                        }
4319                        if let Some(arg) = arg3 {
4320                            write!(self.out, ", ")?;
4321                            self.write_expr(module, arg, func_ctx)?;
4322                        }
4323                        write!(self.out, ")")?
4324                    }
4325                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4326                    // as non-32bit types are DXC only.
4327                    Function::MissingIntOverload(fun_name) => {
4328                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4329                        if let Some(Scalar::I32) = scalar_kind {
4330                            write!(self.out, "asint({fun_name}(asuint(")?;
4331                            self.write_expr(module, arg, func_ctx)?;
4332                            write!(self.out, ")))")?;
4333                        } else {
4334                            write!(self.out, "{fun_name}(")?;
4335                            self.write_expr(module, arg, func_ctx)?;
4336                            write!(self.out, ")")?;
4337                        }
4338                    }
4339                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4340                    // as non-32bit types are DXC only.
4341                    Function::MissingIntReturnType(fun_name) => {
4342                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4343                        if let Some(Scalar::I32) = scalar_kind {
4344                            write!(self.out, "asint({fun_name}(")?;
4345                            self.write_expr(module, arg, func_ctx)?;
4346                            write!(self.out, "))")?;
4347                        } else {
4348                            write!(self.out, "{fun_name}(")?;
4349                            self.write_expr(module, arg, func_ctx)?;
4350                            write!(self.out, ")")?;
4351                        }
4352                    }
4353                    Function::CountTrailingZeros => {
4354                        match *func_ctx.resolve_type(arg, &module.types) {
4355                            TypeInner::Vector { size, scalar } => {
4356                                let s = match size {
4357                                    crate::VectorSize::Bi => ".xx",
4358                                    crate::VectorSize::Tri => ".xxx",
4359                                    crate::VectorSize::Quad => ".xxxx",
4360                                };
4361
4362                                let scalar_width_bits = scalar.width * 8;
4363
4364                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4365                                    write!(
4366                                        self.out,
4367                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4368                                    )?;
4369                                    self.write_expr(module, arg, func_ctx)?;
4370                                    write!(self.out, "))")?;
4371                                } else {
4372                                    // This is only needed for the FXC path, on 32bit signed integers.
4373                                    write!(
4374                                        self.out,
4375                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4376                                    )?;
4377                                    self.write_expr(module, arg, func_ctx)?;
4378                                    write!(self.out, ")))")?;
4379                                }
4380                            }
4381                            TypeInner::Scalar(scalar) => {
4382                                let scalar_width_bits = scalar.width * 8;
4383
4384                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4385                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4386                                    self.write_expr(module, arg, func_ctx)?;
4387                                    write!(self.out, "))")?;
4388                                } else {
4389                                    // This is only needed for the FXC path, on 32bit signed integers.
4390                                    write!(
4391                                        self.out,
4392                                        "asint(min({scalar_width_bits}u, firstbitlow("
4393                                    )?;
4394                                    self.write_expr(module, arg, func_ctx)?;
4395                                    write!(self.out, ")))")?;
4396                                }
4397                            }
4398                            _ => unreachable!(),
4399                        }
4400
4401                        return Ok(());
4402                    }
4403                    Function::CountLeadingZeros => {
4404                        match *func_ctx.resolve_type(arg, &module.types) {
4405                            TypeInner::Vector { size, scalar } => {
4406                                let s = match size {
4407                                    crate::VectorSize::Bi => ".xx",
4408                                    crate::VectorSize::Tri => ".xxx",
4409                                    crate::VectorSize::Quad => ".xxxx",
4410                                };
4411
4412                                // scalar width - 1
4413                                let constant = scalar.width * 8 - 1;
4414
4415                                if scalar.kind == ScalarKind::Uint {
4416                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4417                                    self.write_expr(module, arg, func_ctx)?;
4418                                    write!(self.out, "))")?;
4419                                } else {
4420                                    let conversion_func = match scalar.width {
4421                                        4 => "asint",
4422                                        _ => "",
4423                                    };
4424                                    write!(self.out, "(")?;
4425                                    self.write_expr(module, arg, func_ctx)?;
4426                                    write!(
4427                                        self.out,
4428                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4429                                    )?;
4430                                    self.write_expr(module, arg, func_ctx)?;
4431                                    write!(self.out, ")))")?;
4432                                }
4433                            }
4434                            TypeInner::Scalar(scalar) => {
4435                                // scalar width - 1
4436                                let constant = scalar.width * 8 - 1;
4437
4438                                if let ScalarKind::Uint = scalar.kind {
4439                                    write!(self.out, "({constant}u - firstbithigh(")?;
4440                                    self.write_expr(module, arg, func_ctx)?;
4441                                    write!(self.out, "))")?;
4442                                } else {
4443                                    let conversion_func = match scalar.width {
4444                                        4 => "asint",
4445                                        _ => "",
4446                                    };
4447                                    write!(self.out, "(")?;
4448                                    self.write_expr(module, arg, func_ctx)?;
4449                                    write!(
4450                                        self.out,
4451                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4452                                    )?;
4453                                    self.write_expr(module, arg, func_ctx)?;
4454                                    write!(self.out, ")))")?;
4455                                }
4456                            }
4457                            _ => unreachable!(),
4458                        }
4459
4460                        return Ok(());
4461                    }
4462                }
4463            }
4464            Expression::Swizzle {
4465                size,
4466                vector,
4467                pattern,
4468            } => {
4469                self.write_expr(module, vector, func_ctx)?;
4470                write!(self.out, ".")?;
4471                for &sc in pattern[..size as usize].iter() {
4472                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4473                }
4474            }
4475            Expression::ArrayLength(expr) => {
4476                let var_handle = match func_ctx.expressions[expr] {
4477                    Expression::AccessIndex { base, index: _ } => {
4478                        match func_ctx.expressions[base] {
4479                            Expression::GlobalVariable(handle) => handle,
4480                            _ => unreachable!(),
4481                        }
4482                    }
4483                    Expression::GlobalVariable(handle) => handle,
4484                    _ => unreachable!(),
4485                };
4486
4487                let var = &module.global_variables[var_handle];
4488                let (offset, stride) = match module.types[var.ty].inner {
4489                    TypeInner::Array { stride, .. } => (0, stride),
4490                    TypeInner::Struct { ref members, .. } => {
4491                        let last = members.last().unwrap();
4492                        let stride = match module.types[last.ty].inner {
4493                            TypeInner::Array { stride, .. } => stride,
4494                            _ => unreachable!(),
4495                        };
4496                        (last.offset, stride)
4497                    }
4498                    _ => unreachable!(),
4499                };
4500
4501                let storage_access = match var.space {
4502                    crate::AddressSpace::Storage { access } => access,
4503                    _ => crate::StorageAccess::default(),
4504                };
4505                let wrapped_array_length = WrappedArrayLength {
4506                    writable: storage_access.contains(crate::StorageAccess::STORE),
4507                };
4508
4509                write!(self.out, "((")?;
4510                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4511                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4512                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4513            }
4514            Expression::Derivative { axis, ctrl, expr } => {
4515                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4516                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4517                    let tail = match ctrl {
4518                        Ctrl::Coarse => "coarse",
4519                        Ctrl::Fine => "fine",
4520                        Ctrl::None => unreachable!(),
4521                    };
4522                    write!(self.out, "abs(ddx_{tail}(")?;
4523                    self.write_expr(module, expr, func_ctx)?;
4524                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4525                    self.write_expr(module, expr, func_ctx)?;
4526                    write!(self.out, "))")?
4527                } else {
4528                    let fun_str = match (axis, ctrl) {
4529                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4530                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4531                        (Axis::X, Ctrl::None) => "ddx",
4532                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4533                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4534                        (Axis::Y, Ctrl::None) => "ddy",
4535                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4536                        (Axis::Width, Ctrl::None) => "fwidth",
4537                    };
4538                    write!(self.out, "{fun_str}(")?;
4539                    self.write_expr(module, expr, func_ctx)?;
4540                    write!(self.out, ")")?
4541                }
4542            }
4543            Expression::Relational { fun, argument } => {
4544                use crate::RelationalFunction as Rf;
4545
4546                let fun_str = match fun {
4547                    Rf::All => "all",
4548                    Rf::Any => "any",
4549                    Rf::IsNan => "isnan",
4550                    Rf::IsInf => "isinf",
4551                };
4552                write!(self.out, "{fun_str}(")?;
4553                self.write_expr(module, argument, func_ctx)?;
4554                write!(self.out, ")")?
4555            }
4556            Expression::Select {
4557                condition,
4558                accept,
4559                reject,
4560            } => {
4561                write!(self.out, "(")?;
4562                self.write_expr(module, condition, func_ctx)?;
4563                write!(self.out, " ? ")?;
4564                self.write_expr(module, accept, func_ctx)?;
4565                write!(self.out, " : ")?;
4566                self.write_expr(module, reject, func_ctx)?;
4567                write!(self.out, ")")?
4568            }
4569            Expression::RayQueryGetIntersection { query, committed } => {
4570                // For reasoning, see write_stmt
4571                let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4572                    unreachable!()
4573                };
4574
4575                let tracker_expr_name = format!(
4576                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4577                    self.names[&func_ctx.name_key(query_var)]
4578                );
4579
4580                if committed {
4581                    write!(self.out, "GetCommittedIntersection(")?;
4582                    self.write_expr(module, query, func_ctx)?;
4583                    write!(self.out, ", {tracker_expr_name})")?;
4584                } else {
4585                    write!(self.out, "GetCandidateIntersection(")?;
4586                    self.write_expr(module, query, func_ctx)?;
4587                    write!(self.out, ", {tracker_expr_name})")?;
4588                }
4589            }
4590            // Not supported yet
4591            Expression::RayQueryVertexPositions { .. }
4592            | Expression::CooperativeLoad { .. }
4593            | Expression::CooperativeMultiplyAdd { .. } => {
4594                unreachable!()
4595            }
4596            // Nothing to do here, since call expression already cached
4597            Expression::CallResult(_)
4598            | Expression::AtomicResult { .. }
4599            | Expression::WorkGroupUniformLoadResult { .. }
4600            | Expression::RayQueryProceedResult
4601            | Expression::SubgroupBallotResult
4602            | Expression::SubgroupOperationResult { .. } => {}
4603        }
4604
4605        if !closing_bracket.is_empty() {
4606            write!(self.out, "{closing_bracket}")?;
4607        }
4608        Ok(())
4609    }
4610
4611    #[allow(clippy::too_many_arguments)]
4612    fn write_image_load(
4613        &mut self,
4614        module: &&Module,
4615        expr: Handle<crate::Expression>,
4616        func_ctx: &back::FunctionCtx,
4617        image: Handle<crate::Expression>,
4618        coordinate: Handle<crate::Expression>,
4619        array_index: Option<Handle<crate::Expression>>,
4620        sample: Option<Handle<crate::Expression>>,
4621        level: Option<Handle<crate::Expression>>,
4622    ) -> Result<(), Error> {
4623        let mut wrapping_type = None;
4624        match *func_ctx.resolve_type(image, &module.types) {
4625            TypeInner::Image {
4626                class: crate::ImageClass::External,
4627                ..
4628            } => {
4629                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4630                self.write_expr(module, image, func_ctx)?;
4631                write!(self.out, ", ")?;
4632                self.write_expr(module, coordinate, func_ctx)?;
4633                write!(self.out, ")")?;
4634                return Ok(());
4635            }
4636            TypeInner::Image {
4637                class: crate::ImageClass::Storage { format, .. },
4638                ..
4639            } => {
4640                if format.single_component() {
4641                    wrapping_type = Some(Scalar::from(format));
4642                }
4643            }
4644            _ => {}
4645        }
4646        if let Some(scalar) = wrapping_type {
4647            write!(
4648                self.out,
4649                "{}{}(",
4650                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4651                scalar.to_hlsl_str()?
4652            )?;
4653        }
4654        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4655        self.write_expr(module, image, func_ctx)?;
4656        write!(self.out, ".Load(")?;
4657
4658        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4659
4660        if let Some(sample) = sample {
4661            write!(self.out, ", ")?;
4662            self.write_expr(module, sample, func_ctx)?;
4663        }
4664
4665        // close bracket for Load function
4666        write!(self.out, ")")?;
4667
4668        if wrapping_type.is_some() {
4669            write!(self.out, ")")?;
4670        }
4671
4672        // return x component if return type is scalar
4673        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4674            write!(self.out, ".x")?;
4675        }
4676        Ok(())
4677    }
4678
4679    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4680    /// can be generated later.
4681    fn sampler_binding_array_info_from_expression(
4682        &mut self,
4683        module: &Module,
4684        func_ctx: &back::FunctionCtx<'_>,
4685        base: Handle<crate::Expression>,
4686        resolved: &TypeInner,
4687    ) -> Option<BindingArraySamplerInfo> {
4688        if let TypeInner::BindingArray {
4689            base: base_ty_handle,
4690            ..
4691        } = *resolved
4692        {
4693            let base_ty = &module.types[base_ty_handle].inner;
4694            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4695                let base = &func_ctx.expressions[base];
4696
4697                if let crate::Expression::GlobalVariable(handle) = *base {
4698                    let variable = &module.global_variables[handle];
4699
4700                    let sampler_heap_name = match comparison {
4701                        true => COMPARISON_SAMPLER_HEAP_VAR,
4702                        false => SAMPLER_HEAP_VAR,
4703                    };
4704
4705                    return Some(BindingArraySamplerInfo {
4706                        sampler_heap_name,
4707                        sampler_index_buffer_name: self
4708                            .wrapped
4709                            .sampler_index_buffers
4710                            .get(&super::SamplerIndexBufferKey {
4711                                group: variable.binding.unwrap().group,
4712                            })
4713                            .unwrap()
4714                            .clone(),
4715                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4716                            .clone(),
4717                    });
4718                }
4719            }
4720        }
4721
4722        None
4723    }
4724
4725    fn write_named_expr(
4726        &mut self,
4727        module: &Module,
4728        handle: Handle<crate::Expression>,
4729        name: String,
4730        // The expression which is being named.
4731        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4732        expr: Handle<crate::Expression>,
4733        func_ctx: &back::FunctionCtx,
4734    ) -> BackendResult {
4735        if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4736            let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4737            if ty_inner.is_atomic_pointer(&module.types) {
4738                let pointer_space = ty_inner.pointer_space().unwrap();
4739                self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4740                write!(self.out, " {name}; ")?;
4741                match pointer_space {
4742                    crate::AddressSpace::WorkGroup => {
4743                        write!(self.out, "InterlockedOr(")?;
4744                        self.write_expr(module, pointer, func_ctx)?;
4745                    }
4746                    crate::AddressSpace::Storage { .. } => {
4747                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4748                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4749                        write!(self.out, "{var_name}.InterlockedOr(")?;
4750                        let chain = mem::take(&mut self.temp_access_chain);
4751                        self.write_storage_address(module, &chain, func_ctx)?;
4752                        self.temp_access_chain = chain;
4753                    }
4754                    _ => unreachable!(),
4755                }
4756                writeln!(self.out, ", 0, {name});")?;
4757                self.named_expressions.insert(expr, name);
4758                return Ok(());
4759            }
4760        }
4761        match func_ctx.info[expr].ty {
4762            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4763                TypeInner::Struct { .. } => {
4764                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4765                    write!(self.out, "{ty_name}")?;
4766                }
4767                _ => {
4768                    self.write_type(module, ty_handle)?;
4769                }
4770            },
4771            proc::TypeResolution::Value(ref inner) => {
4772                self.write_value_type(module, inner)?;
4773            }
4774        }
4775
4776        let resolved = func_ctx.resolve_type(expr, &module.types);
4777
4778        write!(self.out, " {name}")?;
4779        // If rhs is a array type, we should write array size
4780        if let TypeInner::Array { base, size, .. } = *resolved {
4781            self.write_array_size(module, base, size)?;
4782        }
4783        write!(self.out, " = ")?;
4784        self.write_expr(module, handle, func_ctx)?;
4785        writeln!(self.out, ";")?;
4786        self.named_expressions.insert(expr, name);
4787
4788        Ok(())
4789    }
4790
4791    /// Helper function that write default zero initialization
4792    pub(super) fn write_default_init(
4793        &mut self,
4794        module: &Module,
4795        ty: Handle<crate::Type>,
4796    ) -> BackendResult {
4797        write!(self.out, "(")?;
4798        self.write_type(module, ty)?;
4799        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4800            self.write_array_size(module, base, size)?;
4801        }
4802        write!(self.out, ")0")?;
4803        Ok(())
4804    }
4805
4806    pub(super) fn write_control_barrier(
4807        &mut self,
4808        barrier: crate::Barrier,
4809        level: back::Level,
4810    ) -> BackendResult {
4811        if barrier.contains(crate::Barrier::STORAGE) {
4812            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4813        }
4814        if barrier.contains(crate::Barrier::WORK_GROUP) {
4815            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4816        }
4817        if barrier.contains(crate::Barrier::SUB_GROUP) {
4818            // Does not exist in DirectX
4819        }
4820        if barrier.contains(crate::Barrier::TEXTURE) {
4821            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4822        }
4823        Ok(())
4824    }
4825
4826    fn write_memory_barrier(
4827        &mut self,
4828        barrier: crate::Barrier,
4829        level: back::Level,
4830    ) -> BackendResult {
4831        if barrier.contains(crate::Barrier::STORAGE) {
4832            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4833        }
4834        if barrier.contains(crate::Barrier::WORK_GROUP) {
4835            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4836        }
4837        if barrier.contains(crate::Barrier::SUB_GROUP) {
4838            // Does not exist in DirectX
4839        }
4840        if barrier.contains(crate::Barrier::TEXTURE) {
4841            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4842        }
4843        Ok(())
4844    }
4845
4846    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4847    fn emit_hlsl_atomic_tail(
4848        &mut self,
4849        module: &Module,
4850        func_ctx: &back::FunctionCtx<'_>,
4851        fun: &crate::AtomicFunction,
4852        compare_expr: Option<Handle<crate::Expression>>,
4853        value: Handle<crate::Expression>,
4854        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4855    ) -> BackendResult {
4856        if let Some(cmp) = compare_expr {
4857            write!(self.out, ", ")?;
4858            self.write_expr(module, cmp, func_ctx)?;
4859        }
4860        write!(self.out, ", ")?;
4861        if let crate::AtomicFunction::Subtract = *fun {
4862            // we just wrote `InterlockedAdd`, so negate the argument
4863            write!(self.out, "-")?;
4864        }
4865        self.write_expr(module, value, func_ctx)?;
4866        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4867            write!(self.out, ", ")?;
4868            if compare_expr.is_some() {
4869                write!(self.out, "{res_name}.old_value")?;
4870            } else {
4871                write!(self.out, "{res_name}")?;
4872            }
4873        }
4874        writeln!(self.out, ");")?;
4875        Ok(())
4876    }
4877}
4878
/// Shape and component width of a matrix type, as extracted from a
/// [`TypeInner::Matrix`] by the matrix-lookup helpers below.
pub(super) struct MatrixType {
    /// The matrix's `columns` field from [`TypeInner::Matrix`].
    pub(super) columns: crate::VectorSize,
    /// The matrix's `rows` field from [`TypeInner::Matrix`].
    pub(super) rows: crate::VectorSize,
    /// Width of each scalar component, copied from the matrix's `scalar.width`.
    pub(super) width: crate::Bytes,
}
4884
4885pub(super) fn get_inner_matrix_data(
4886    module: &Module,
4887    handle: Handle<crate::Type>,
4888) -> Option<MatrixType> {
4889    match module.types[handle].inner {
4890        TypeInner::Matrix {
4891            columns,
4892            rows,
4893            scalar,
4894        } => Some(MatrixType {
4895            columns,
4896            rows,
4897            width: scalar.width,
4898        }),
4899        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4900        _ => None,
4901    }
4902}
4903
4904/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4905/// returns a tuple of the matrix, the column (vector) index (if present), and
4906/// the row (scalar) index (if present).
4907fn find_matrix_in_access_chain(
4908    module: &Module,
4909    base: Handle<crate::Expression>,
4910    func_ctx: &back::FunctionCtx<'_>,
4911) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4912    let mut current_base = base;
4913    let mut vector = None;
4914    let mut scalar = None;
4915    loop {
4916        let resolved_tr = func_ctx
4917            .resolve_type(current_base, &module.types)
4918            .pointer_base_type();
4919        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4920
4921        match *resolved {
4922            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4923            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4924            _ => return None,
4925        }
4926
4927        let index;
4928        (current_base, index) = match func_ctx.expressions[current_base] {
4929            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4930            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4931            _ => return None,
4932        };
4933
4934        match *resolved {
4935            TypeInner::Scalar(_) => scalar = Some(index),
4936            TypeInner::Vector { .. } => vector = Some(index),
4937            _ => unreachable!(),
4938        }
4939    }
4940}
4941
4942/// Returns the matrix data if the access chain starting at `base`:
4943/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4944/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4945/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4946pub(super) fn get_inner_matrix_of_struct_array_member(
4947    module: &Module,
4948    base: Handle<crate::Expression>,
4949    func_ctx: &back::FunctionCtx<'_>,
4950    direct: bool,
4951) -> Option<MatrixType> {
4952    let mut mat_data = None;
4953    let mut array_base = None;
4954
4955    let mut current_base = base;
4956    loop {
4957        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
4958        if let TypeInner::Pointer { base, .. } = *resolved {
4959            resolved = &module.types[base].inner;
4960        };
4961
4962        match *resolved {
4963            TypeInner::Matrix {
4964                columns,
4965                rows,
4966                scalar,
4967            } => {
4968                mat_data = Some(MatrixType {
4969                    columns,
4970                    rows,
4971                    width: scalar.width,
4972                })
4973            }
4974            TypeInner::Array { base, .. } => {
4975                array_base = Some(base);
4976            }
4977            TypeInner::Struct { .. } => {
4978                if let Some(array_base) = array_base {
4979                    if direct {
4980                        return mat_data;
4981                    } else {
4982                        return get_inner_matrix_data(module, array_base);
4983                    }
4984                }
4985
4986                break;
4987            }
4988            _ => break,
4989        }
4990
4991        current_base = match func_ctx.expressions[current_base] {
4992            crate::Expression::Access { base, .. } => base,
4993            crate::Expression::AccessIndex { base, .. } => base,
4994            _ => break,
4995        };
4996    }
4997    None
4998}
4999
5000/// Simpler version of get_inner_matrix_of_global_uniform that only looks at the
5001/// immediate expression, rather than traversing an access chain.
5002fn get_global_uniform_matrix(
5003    module: &Module,
5004    base: Handle<crate::Expression>,
5005    func_ctx: &back::FunctionCtx<'_>,
5006) -> Option<MatrixType> {
5007    let base_tr = func_ctx
5008        .resolve_type(base, &module.types)
5009        .pointer_base_type();
5010    let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
5011    match (&func_ctx.expressions[base], base_ty) {
5012        (
5013            &crate::Expression::GlobalVariable(handle),
5014            Some(&TypeInner::Matrix {
5015                columns,
5016                rows,
5017                scalar,
5018            }),
5019        ) if module.global_variables[handle].space == crate::AddressSpace::Uniform => {
5020            Some(MatrixType {
5021                columns,
5022                rows,
5023                width: scalar.width,
5024            })
5025        }
5026        _ => None,
5027    }
5028}
5029
5030/// Returns the matrix data if the access chain starting at `base`:
5031/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
5032/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
5033/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
5034fn get_inner_matrix_of_global_uniform(
5035    module: &Module,
5036    base: Handle<crate::Expression>,
5037    func_ctx: &back::FunctionCtx<'_>,
5038) -> Option<MatrixType> {
5039    let mut mat_data = None;
5040    let mut array_base = None;
5041
5042    let mut current_base = base;
5043    loop {
5044        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5045        if let TypeInner::Pointer { base, .. } = *resolved {
5046            resolved = &module.types[base].inner;
5047        };
5048
5049        match *resolved {
5050            TypeInner::Matrix {
5051                columns,
5052                rows,
5053                scalar,
5054            } => {
5055                mat_data = Some(MatrixType {
5056                    columns,
5057                    rows,
5058                    width: scalar.width,
5059                })
5060            }
5061            TypeInner::Array { base, .. } => {
5062                array_base = Some(base);
5063            }
5064            _ => break,
5065        }
5066
5067        current_base = match func_ctx.expressions[current_base] {
5068            crate::Expression::Access { base, .. } => base,
5069            crate::Expression::AccessIndex { base, .. } => base,
5070            crate::Expression::GlobalVariable(handle)
5071                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5072            {
5073                return mat_data.or_else(|| {
5074                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5075                })
5076            }
5077            _ => break,
5078        };
5079    }
5080    None
5081}