naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::{self, Write as _},
8    mem,
9};
10
11use super::{
12    help,
13    help::{
14        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15        WrappedZeroValue,
16    },
17    storage::StoreValue,
18    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21    back::{self, get_entry_points, Baked},
22    common,
23    proc::{self, index, ExternalTextureNameKey, NameKey},
24    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
// ---------------------------------------------------------------------------
// Interface semantics and the optional special-constants cbuffer.
// ---------------------------------------------------------------------------

/// Semantic prefix used for user locations that are not fragment outputs
/// (emitted as `LOC{n}`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Type name of the struct backing the special-constants cbuffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Variable name of the special-constants cbuffer.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special-constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special-constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` field of the special-constants struct.
const SPECIAL_OTHER: &str = "other";

// ---------------------------------------------------------------------------
// Names of HLSL helper functions the backend may emit (the emitting code
// lives outside this part of the file).
// ---------------------------------------------------------------------------

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";

// ---------------------------------------------------------------------------
// Names of backend-generated global variables.
// ---------------------------------------------------------------------------

pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";

// ---------------------------------------------------------------------------
// Identifier prefixes.
// ---------------------------------------------------------------------------

/// Prefix for per-ray-query tracker variables.
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix for variables the backend introduces inside a naga statement.
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
/// An index operand that is either a runtime expression or a literal `u32`.
///
/// NOTE(review): presumably used when writing access chains; the users are
/// outside this part of the file — confirm.
enum Index {
    /// Index computed by a naga expression.
    Expression(Handle<crate::Expression>),
    /// Index known as a constant value.
    Static(u32),
}
60
/// One member of a flattened entry point interface struct.
pub(super) struct EpStructMember {
    /// Member name, as produced by the writer's namer.
    pub(super) name: String,
    /// Type of the member.
    pub(super) ty: Handle<crate::Type>,
    /// Interface binding (location or built-in) of the member.
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    pub(super) binding: Option<crate::Binding>,
    /// Position of the member in the flattened list before sorting; used to
    /// restore the original order of input members.
    pub(super) index: u32,
}
69
/// Information needed to generate, and later refer to, the wrapper struct that
/// gathers the flattened inputs or outputs of an entry point.
pub(super) struct EntryPointBinding {
    /// Name of the fake EP argument that contains the struct
    /// with all the flattened input data.
    pub(super) arg_name: String,
    /// Generated structure name
    pub(super) ty_name: String,
    /// Members of generated structure
    pub(super) members: Vec<EpStructMember>,
    /// Name of the struct member holding the local invocation index.
    ///
    /// This is either the member bound to `local_invocation_index`, or a
    /// `SV_GroupIndex` member synthesized because `subgroup_id` was used
    /// without it. `None` when neither applies.
    pub(super) local_invocation_index_name: Option<String>,
}
82
/// The flattened I/O interface structs generated for an entry point.
pub(super) struct EntryPointInterface {
    /// If `Some`, the input of an entry point is gathered in a special
    /// struct whose HLSL declaration lists the members sorted by binding.
    /// The `EntryPointBinding::members` array itself is re-sorted by index
    /// (the original argument order), so that we can walk it in
    /// `write_ep_arguments_initialization`.
    pub(crate) input: Option<EntryPointBinding>,
    /// If `Some`, the output of an entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in the `Statement::Return` handler.
    pub(crate) output: Option<EntryPointBinding>,
    /// Mesh shader vertex outputs, if any. Presumably generated with
    /// `Io::MeshVertices`, whose members stay sorted by binding — confirm.
    pub(crate) mesh_vertices: Option<EntryPointBinding>,
    /// Mesh shader primitive outputs, if any. Presumably generated with
    /// `Io::MeshPrimitives`, whose members stay sorted by binding — confirm.
    pub(crate) mesh_primitives: Option<EntryPointBinding>,
    /// Mesh shader index outputs, if any.
    /// NOTE(review): ordering/producer not visible here — confirm.
    pub(crate) mesh_indices: Option<EntryPointBinding>,
}
97
98#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
99enum InterfaceKey {
100    Location(u32),
101    BuiltIn(crate::BuiltIn),
102    Other,
103}
104
105impl InterfaceKey {
106    const fn new(binding: Option<&crate::Binding>) -> Self {
107        match binding {
108            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110            None => Self::Other,
111        }
112    }
113}
114
/// Which kind of entry point interface struct is being generated.
///
/// `write_interface_struct` re-sorts `Input` members back into their original
/// argument order. For the other variants it keeps the members sorted by
/// binding. `write_semantic` uses `(ShaderStage::Fragment, Io::Output)` to
/// select `SV_Target{n}` semantics.
#[derive(Copy, Clone, PartialEq)]
pub(super) enum Io {
    /// Entry point arguments.
    Input,
    /// Entry point result.
    Output,
    /// Mesh shader per-vertex outputs.
    MeshVertices,
    /// Mesh shader per-primitive outputs.
    MeshPrimitives,
}
122
/// Argument list for nested entry points
///
/// NOTE(review): presumably used for the mesh/task shader case, where the
/// user-written function is called from a generated wrapper entry point —
/// confirm against the users of this struct.
pub(super) struct NestedEntryPointArgs {
    /// Arguments literally declared by the user
    pub user_args: Vec<String>,
    /// Expression/name to pass for the task payload, if the function takes one.
    pub task_payload: Option<String>,
    /// Expression/name to pass as the local invocation index.
    pub local_invocation_index: String,
}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133        return false;
134    };
135    matches!(
136        builtin,
137        crate::BuiltIn::SubgroupSize
138            | crate::BuiltIn::SubgroupInvocationId
139            | crate::BuiltIn::NumSubgroups
140            | crate::BuiltIn::SubgroupId
141    )
142}
143
/// Information for how to generate a `binding_array<sampler>` access.
///
/// NOTE(review): presumably the access reads an index from the sampler index
/// buffer (offset by the base index) and uses it to index the sampler heap —
/// confirm against the code that consumes this struct.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    /// (presumably `SAMPLER_HEAP_VAR` or `COMPARISON_SAMPLER_HEAP_VAR` — confirm).
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156        Self {
157            out,
158            names: crate::FastHashMap::default(),
159            namer: proc::Namer::default(),
160            options,
161            pipeline_options,
162            entry_point_io: crate::FastHashMap::default(),
163            named_expressions: crate::NamedExpressions::default(),
164            wrapped: super::Wrapped::default(),
165            written_committed_intersection: false,
166            written_candidate_intersection: false,
167            continue_ctx: back::continue_forward::ContinueCtx::default(),
168            temp_access_chain: Vec::new(),
169            need_bake_expressions: Default::default(),
170            function_task_payload_var: Default::default(),
171        }
172    }
173
174    fn reset(&mut self, module: &Module) {
175        self.names.clear();
176        self.namer.reset(
177            module,
178            &super::keywords::RESERVED_SET,
179            proc::KeywordSet::empty(),
180            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181            super::keywords::RESERVED_PREFIXES,
182            &mut self.names,
183        );
184        self.entry_point_io.clear();
185        self.named_expressions.clear();
186        self.wrapped.clear();
187        self.written_committed_intersection = false;
188        self.written_candidate_intersection = false;
189        self.continue_ctx.clear();
190        self.need_bake_expressions.clear();
191        self.function_task_payload_var.clear();
192    }
193
194    /// Generates statements to be inserted immediately before and at the very
195    /// start of the body of each loop, to defeat infinite loop reasoning.
196    /// The 0th item of the returned tuple should be inserted immediately prior
197    /// to the loop and the 1st item should be inserted at the very start of
198    /// the loop body.
199    ///
200    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
201    fn gen_force_bounded_loop_statements(
202        &mut self,
203        level: back::Level,
204    ) -> Option<(String, String)> {
205        if !self.options.force_loop_bounding {
206            return None;
207        }
208
209        let loop_bound_name = self.namer.call("loop_bound");
210        let max = u32::MAX;
211        // Count down from u32::MAX rather than up from 0 to avoid hang on
212        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
213        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214        let level = level.next();
215        let break_and_inc = format!(
216            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218        );
219
220        Some((decl, break_and_inc))
221    }
222
223    /// Helper method used to find which expressions of a given function require baking
224    ///
225    /// # Notes
226    /// Clears `need_bake_expressions` set before adding to it
227    fn update_expressions_to_bake(
228        &mut self,
229        module: &Module,
230        func: &crate::Function,
231        info: &valid::FunctionInfo,
232    ) {
233        use crate::Expression;
234        self.need_bake_expressions.clear();
235        for (exp_handle, expr) in func.expressions.iter() {
236            let expr_info = &info[exp_handle];
237            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238            if min_ref_count <= expr_info.ref_count {
239                self.need_bake_expressions.insert(exp_handle);
240            }
241            if let Expression::Load { pointer } = *expr {
242                if info[pointer]
243                    .ty
244                    .inner_with(&module.types)
245                    .is_atomic_pointer(&module.types)
246                {
247                    self.need_bake_expressions.insert(exp_handle);
248                }
249            }
250
251            if let Expression::Math { fun, arg, arg1, .. } = *expr {
252                match fun {
253                    crate::MathFunction::Asinh
254                    | crate::MathFunction::Acosh
255                    | crate::MathFunction::Atanh
256                    | crate::MathFunction::Unpack2x16float
257                    | crate::MathFunction::Unpack2x16snorm
258                    | crate::MathFunction::Unpack2x16unorm
259                    | crate::MathFunction::Unpack4x8snorm
260                    | crate::MathFunction::Unpack4x8unorm
261                    | crate::MathFunction::Unpack4xI8
262                    | crate::MathFunction::Unpack4xU8
263                    | crate::MathFunction::Pack2x16float
264                    | crate::MathFunction::Pack2x16snorm
265                    | crate::MathFunction::Pack2x16unorm
266                    | crate::MathFunction::Pack4x8snorm
267                    | crate::MathFunction::Pack4x8unorm
268                    | crate::MathFunction::Pack4xI8
269                    | crate::MathFunction::Pack4xU8
270                    | crate::MathFunction::Pack4xI8Clamp
271                    | crate::MathFunction::Pack4xU8Clamp => {
272                        self.need_bake_expressions.insert(arg);
273                    }
274                    crate::MathFunction::CountLeadingZeros => {
275                        let inner = info[exp_handle].ty.inner_with(&module.types);
276                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277                            self.need_bake_expressions.insert(arg);
278                        }
279                    }
280                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281                        self.need_bake_expressions.insert(arg);
282                        self.need_bake_expressions.insert(arg1.unwrap());
283                    }
284                    _ => {}
285                }
286            }
287
288            if let Expression::Derivative { axis, ctrl, expr } = *expr {
289                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291                    self.need_bake_expressions.insert(expr);
292                }
293            }
294
295            if let Expression::GlobalVariable(_) = *expr {
296                let inner = info[exp_handle].ty.inner_with(&module.types);
297
298                if let TypeInner::Sampler { .. } = *inner {
299                    self.need_bake_expressions.insert(exp_handle);
300                }
301            }
302        }
303        for statement in func.body.iter() {
304            match *statement {
305                crate::Statement::SubgroupCollectiveOperation {
306                    op: _,
307                    collective_op: crate::CollectiveOperation::InclusiveScan,
308                    argument,
309                    result: _,
310                } => {
311                    self.need_bake_expressions.insert(argument);
312                }
313                crate::Statement::Atomic {
314                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315                    ..
316                } => {
317                    self.need_bake_expressions.insert(cmp);
318                }
319                _ => {}
320            }
321        }
322    }
323
    /// Writes the whole `module` as HLSL to `self.out`.
    ///
    /// Output order:
    /// 1. the special-constants cbuffer (if configured);
    /// 2. the dynamic storage buffer offset cbuffers;
    /// 3. `matCx2` typedefs and helper functions;
    /// 4. all sized structs;
    /// 5. special and wrapped helper functions;
    /// 6. named constants and globals;
    /// 7. entry point interface structs;
    /// 8. regular functions;
    /// 9. finally, the selected entry points.
    ///
    /// When `fake_missing_bindings` is false, a used global whose binding resolves
    /// neither as a resource nor as an external texture does not abort the call.
    /// Regular functions using it are skipped. Entry points using it are reported
    /// as `Err` in the returned `entry_point_names`.
    ///
    /// # Errors
    /// - [`Error::ShaderModelTooLow`] if the module uses mesh shaders and the
    ///   target shader model is below 6.5.
    /// - [`Error::EntryPointNotFound`] if `get_entry_points` rejects the entry
    ///   point requested in `pipeline_options.entry_point`.
    /// - Any error produced while writing.
    pub fn write(
        &mut self,
        module: &Module,
        module_info: &valid::ModuleInfo,
        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
    ) -> Result<super::ReflectionInfo, Error> {
        self.reset(module);

        // Reject mesh shaders up front, before anything is written.
        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
            return Err(Error::ShaderModelTooLow(
                "mesh shaders".to_string(),
                ShaderModel::V6_5,
            ));
        }

        // Write special constants, if needed
        if let Some(ref bt) = self.options.special_constants_binding {
            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
            writeln!(self.out, "}};")?;
            write!(
                self.out,
                "ConstantBuffer<{}> {}: register(b{}",
                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
            )?;
            // The `space` qualifier is only written when it is non-zero.
            if bt.space != 0 {
                write!(self.out, ", space{}", bt.space)?;
            }
            writeln!(self.out, ");")?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // One cbuffer per configured `group`, holding `bt.size` `uint` offsets.
        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
            for i in 0..bt.size {
                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
            }
            writeln!(self.out, "}};")?;
            writeln!(
                self.out,
                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
                group, group, bt.register, bt.space
            )?;

            // Extra newline for readability
            writeln!(self.out)?;
        }

        // Save all entry point output types
        let ep_results = module
            .entry_points
            .iter()
            .map(|ep| (ep.stage, ep.function.result.clone()))
            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

        self.write_all_mat_cx2_typedefs_and_functions(module)?;

        // Write all structs
        for (handle, ty) in module.types.iter() {
            if let TypeInner::Struct { ref members, span } = ty.inner {
                if module.types[members.last().unwrap().ty]
                    .inner
                    .is_dynamically_sized(&module.types)
                {
                    // unsized arrays can only be in storage buffers,
                    // for which we use `ByteAddressBuffer` anyway.
                    continue;
                }

                // A struct used as an entry point result is written as that
                // stage's output (the first matching entry point wins).
                let ep_result = ep_results.iter().find(|e| {
                    if let Some(ref result) = e.1 {
                        result.ty == handle
                    } else {
                        false
                    }
                });

                self.write_struct(
                    module,
                    handle,
                    members,
                    span,
                    ep_result.map(|r| (r.0, Io::Output)),
                )?;
                writeln!(self.out)?;
            }
        }

        self.write_special_functions(module)?;

        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

        // Write all named constants
        let mut constants = module
            .constants
            .iter()
            .filter(|&(_, c)| c.name.is_some())
            .peekable();
        while let Some((handle, _)) = constants.next() {
            self.write_global_constant(module, handle)?;
            // Add extra newline for readability on last iteration
            if constants.peek().is_none() {
                writeln!(self.out)?;
            }
        }

        // Write all globals
        for (global, _) in module.global_variables.iter() {
            self.write_global(module, global)?;
        }

        if !module.global_variables.is_empty() {
            // Add extra newline for readability
            writeln!(self.out)?;
        }

        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        // Write all entry points wrapped structs
        for index in ep_range.clone() {
            let ep = &module.entry_points[index];
            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
            self.entry_point_io.insert(index, ep_io);
        }

        // Write all regular functions
        for (handle, function) in module.functions.iter() {
            let info = &module_info[handle];

            // Check if all of the globals are accessible
            if !self.options.fake_missing_bindings {
                if let Some((var_handle, _)) =
                    module
                        .global_variables
                        .iter()
                        .find(|&(var_handle, var)| match var.binding {
                            Some(ref binding) if !info[var_handle].is_empty() => {
                                self.options.resolve_resource_binding(binding).is_err()
                                    && self
                                        .options
                                        .resolve_external_texture_resource_binding(binding)
                                        .is_err()
                            }
                            _ => false,
                        })
                {
                    log::debug!(
                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                        handle,
                        function.name,
                        var_handle
                    );
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(handle),
                info,
                expressions: &function.expressions,
                named_expressions: &function.named_expressions,
            };
            let name = self.names[&NameKey::Function(handle)].clone();

            self.write_wrapped_functions(module, &ctx)?;

            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;

            writeln!(self.out)?;
        }

        let mut translated_ep_names = Vec::with_capacity(ep_range.len());

        // Write all entry points
        for index in ep_range {
            let ep = &module.entry_points[index];
            let info = module_info.get_entry_point(index);

            // Same accessibility check as for regular functions, but the
            // failure is reported per entry point instead of only logged.
            if !self.options.fake_missing_bindings {
                let mut ep_error = None;
                for (var_handle, var) in module.global_variables.iter() {
                    match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            if let Err(err) = self.options.resolve_resource_binding(binding) {
                                if self
                                    .options
                                    .resolve_external_texture_resource_binding(binding)
                                    .is_err()
                                {
                                    ep_error = Some(err);
                                    break;
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if let Some(err) = ep_error {
                    translated_ep_names.push(Err(err));
                    continue;
                }
            }

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(index as u16),
                info,
                expressions: &ep.function.expressions,
                named_expressions: &ep.function.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            // Mesh/task shaders have a wrapper entry point which is declared after the "main"
            // user-written function. We therefore cannot always just document the next function.
            let mut attribute_string = String::new();
            if ep.stage.compute_like() {
                // HLSL is calling workgroup size "num threads"
                let num_threads = ep.workgroup_size;
                writeln!(
                    attribute_string,
                    "[numthreads({}, {}, {})]",
                    num_threads[0], num_threads[1], num_threads[2]
                )?;
            }
            if let Some(ref info) = ep.mesh_info {
                // NOTE(review): point topology is assumed to be rejected
                // before reaching this point — confirm.
                let topology_str = match info.topology {
                    crate::MeshOutputTopology::Points => unreachable!(),
                    crate::MeshOutputTopology::Lines => "line",
                    crate::MeshOutputTopology::Triangles => "triangle",
                };
                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
            }

            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;

            // Separate entry points with a blank line, except after the
            // module's last entry point.
            if index < module.entry_points.len() - 1 {
                writeln!(self.out)?;
            }

            translated_ep_names.push(Ok(name));
        }

        Ok(super::ReflectionInfo {
            entry_point_names: translated_ep_names,
        })
    }
578
579    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580        match *binding {
581            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582                write!(self.out, "precise ")?;
583            }
584            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585                write!(self.out, "noperspective ")?;
586            }
587            crate::Binding::Location {
588                interpolation,
589                sampling,
590                ..
591            } => {
592                if let Some(interpolation) = interpolation {
593                    if let Some(string) = interpolation.to_hlsl_str() {
594                        write!(self.out, "{string} ")?
595                    }
596                }
597
598                if let Some(sampling) = sampling {
599                    if let Some(string) = sampling.to_hlsl_str() {
600                        write!(self.out, "{string} ")?
601                    }
602                }
603            }
604            crate::Binding::BuiltIn(_) => {}
605        }
606
607        Ok(())
608    }
609
610    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
611    // if they are struct, so that the `stage` argument here could be omitted.
612    pub(super) fn write_semantic(
613        &mut self,
614        binding: &Option<crate::Binding>,
615        stage: Option<(ShaderStage, Io)>,
616    ) -> BackendResult {
617        let is_per_primitive = match *binding {
618            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619                if builtin == crate::BuiltIn::ViewIndex
620                    && self.options.shader_model < ShaderModel::V6_1
621                {
622                    return Err(Error::ShaderModelTooLow(
623                        "used @builtin(view_index) or SV_ViewID".to_string(),
624                        ShaderModel::V6_1,
625                    ));
626                }
627                if let Some(builtin_str) = builtin.to_hlsl_str()? {
628                    write!(self.out, " : {builtin_str}")?;
629                }
630                false
631            }
632            Some(crate::Binding::Location {
633                blend_src: Some(1),
634                per_primitive,
635                ..
636            }) => {
637                write!(self.out, " : SV_Target1")?;
638                per_primitive
639            }
640            Some(crate::Binding::Location {
641                location,
642                per_primitive,
643                ..
644            }) => {
645                if stage == Some((ShaderStage::Fragment, Io::Output)) {
646                    write!(self.out, " : SV_Target{location}")?;
647                } else {
648                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649                }
650                per_primitive
651            }
652            _ => false,
653        };
654        if is_per_primitive {
655            write!(self.out, " : primitive")?;
656        }
657
658        Ok(())
659    }
660
661    pub(super) fn write_interface_struct(
662        &mut self,
663        module: &Module,
664        shader_stage: (ShaderStage, Io),
665        struct_name: String,
666        var_name: Option<&str>,
667        mut members: Vec<EpStructMember>,
668    ) -> Result<EntryPointBinding, Error> {
669        let struct_name = self.namer.call(&struct_name);
670        // Sort the members so that first come the user-defined varyings
671        // in ascending locations, and then built-ins. This allows VS and FS
672        // interfaces to match with regards to order.
673        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675        write!(self.out, "struct {struct_name}")?;
676        writeln!(self.out, " {{")?;
677        let mut local_invocation_index_name = None;
678        let mut subgroup_id_used = false;
679        for m in members.iter() {
680            // Sanity check that each IO member is a built-in or is assigned a
681            // location. Also see note about nesting in `write_ep_input_struct`.
682            debug_assert!(m.binding.is_some());
683
684            match m.binding {
685                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686                    subgroup_id_used = true;
687                }
688                Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689                    local_invocation_index_name = Some(m.name.clone());
690                }
691                _ => (),
692            }
693
694            if is_subgroup_builtin_binding(&m.binding) {
695                continue;
696            }
697            write!(self.out, "{}", back::INDENT)?;
698            if let Some(ref binding) = m.binding {
699                self.write_modifier(binding)?;
700            }
701            self.write_type(module, m.ty)?;
702            write!(self.out, " {}", &m.name)?;
703            self.write_semantic(&m.binding, Some(shader_stage))?;
704            writeln!(self.out, ";")?;
705        }
706        if subgroup_id_used && local_invocation_index_name.is_none() {
707            let name = self.namer.call("local_invocation_index");
708            writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709            local_invocation_index_name = Some(name);
710        }
711        writeln!(self.out, "}};")?;
712        writeln!(self.out)?;
713
714        // See ordering notes on EntryPointInterface fields
715        match shader_stage.1 {
716            Io::Input => {
717                // bring back the original order
718                members.sort_by_key(|m| m.index);
719            }
720            Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721                // keep it sorted by binding
722            }
723        }
724
725        Ok(EntryPointBinding {
726            arg_name: self
727                .namer
728                .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729            ty_name: struct_name,
730            members,
731            local_invocation_index_name,
732        })
733    }
734
735    /// Flatten all entry point arguments into a single struct.
736    /// This is needed since we need to re-order them: first placing user locations,
737    /// then built-ins.
738    fn write_ep_input_struct(
739        &mut self,
740        module: &Module,
741        func: &crate::Function,
742        stage: ShaderStage,
743        entry_point_name: &str,
744    ) -> Result<EntryPointBinding, Error> {
745        let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747        let mut fake_members = Vec::new();
748        for arg in func.arguments.iter() {
749            // NOTE: We don't need to handle nesting structs. All members must
750            // be either built-ins or assigned a location. I.E. `binding` is
751            // `Some`. This is checked in `VaryingContext::validate`. See:
752            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
753            match module.types[arg.ty].inner {
754                TypeInner::Struct { ref members, .. } => {
755                    for member in members.iter() {
756                        let name = self.namer.call_or(&member.name, "member");
757                        let index = fake_members.len() as u32;
758                        fake_members.push(EpStructMember {
759                            name,
760                            ty: member.ty,
761                            binding: member.binding.clone(),
762                            index,
763                        });
764                    }
765                }
766                _ => {
767                    let member_name = self.namer.call_or(&arg.name, "member");
768                    let index = fake_members.len() as u32;
769                    fake_members.push(EpStructMember {
770                        name: member_name,
771                        ty: arg.ty,
772                        binding: arg.binding.clone(),
773                        index,
774                    });
775                }
776            }
777        }
778
779        self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780    }
781
782    /// Flatten all entry point results into a single struct.
783    /// This is needed since we need to re-order them: first placing user locations,
784    /// then built-ins.
785    fn write_ep_output_struct(
786        &mut self,
787        module: &Module,
788        result: &crate::FunctionResult,
789        stage: ShaderStage,
790        entry_point_name: &str,
791        frag_ep: Option<&FragmentEntryPoint<'_>>,
792    ) -> Result<EntryPointBinding, Error> {
793        let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795        let empty = [];
796        let members = match module.types[result.ty].inner {
797            TypeInner::Struct { ref members, .. } => members,
798            ref other => {
799                log::error!("Unexpected {other:?} output type without a binding");
800                &empty[..]
801            }
802        };
803
804        // Gather list of fragment input locations. We use this below to remove user-defined
805        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
806        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
807        // writer is supplied with information about the fragment entry point.
808        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809            let mut fs_input_locs = Vec::new();
810            for arg in frag_ep.func.arguments.iter() {
811                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813                    Some(crate::Binding::BuiltIn(_)) | None => {}
814                };
815
816                // NOTE: We don't need to handle struct nesting. See note in
817                // `write_ep_input_struct`.
818                match frag_ep.module.types[arg.ty].inner {
819                    TypeInner::Struct { ref members, .. } => {
820                        for member in members.iter() {
821                            push_if_location(&member.binding);
822                        }
823                    }
824                    _ => push_if_location(&arg.binding),
825                }
826            }
827            fs_input_locs.sort();
828            Some(fs_input_locs)
829        } else {
830            None
831        };
832
833        let mut fake_members = Vec::new();
834        for (index, member) in members.iter().enumerate() {
835            if let Some(ref fs_input_locs) = fs_input_locs {
836                match member.binding {
837                    Some(crate::Binding::Location { location, .. }) => {
838                        if fs_input_locs.binary_search(&location).is_err() {
839                            continue;
840                        }
841                    }
842                    Some(crate::Binding::BuiltIn(_)) | None => {}
843                }
844            }
845
846            let member_name = self.namer.call_or(&member.name, "member");
847            fake_members.push(EpStructMember {
848                name: member_name,
849                ty: member.ty,
850                binding: member.binding.clone(),
851                index: index as u32,
852            });
853        }
854
855        self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856    }
857
858    /// Writes special interface structures for an entry point. The special structures have
859    /// all the fields flattened into them and sorted by binding. They are needed to emulate
860    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
861    fn write_ep_interface(
862        &mut self,
863        module: &Module,
864        ep: &crate::EntryPoint,
865        ep_name: &str,
866        frag_ep: Option<&FragmentEntryPoint<'_>>,
867    ) -> Result<EntryPointInterface, Error> {
868        let func = &ep.function;
869        let stage = ep.stage;
870        Ok(EntryPointInterface {
871            input: if !func.arguments.is_empty()
872                && (stage == ShaderStage::Fragment
873                    || func
874                        .arguments
875                        .iter()
876                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877            {
878                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879            } else {
880                None
881            },
882            output: match func.result {
883                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885                }
886                _ => None,
887            },
888            mesh_vertices: if let Some(ref info) = ep.mesh_info {
889                Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890            } else {
891                None
892            },
893            mesh_primitives: if let Some(ref info) = ep.mesh_info {
894                Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895            } else {
896                None
897            },
898            mesh_indices: if let Some(ref info) = ep.mesh_info {
899                Some(self.write_ep_mesh_output_indices(info.topology)?)
900            } else {
901                None
902            },
903        })
904    }
905
906    fn write_ep_argument_initialization(
907        &mut self,
908        ep: &crate::EntryPoint,
909        ep_input: &EntryPointBinding,
910        fake_member: &EpStructMember,
911    ) -> BackendResult {
912        match fake_member.binding {
913            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914                write!(self.out, "WaveGetLaneCount()")?
915            }
916            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917                write!(self.out, "WaveGetLaneIndex()")?
918            }
919            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920                self.out,
921                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923            )?,
924            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925                write!(
926                    self.out,
927                    "{}.{} / WaveGetLaneCount()",
928                    ep_input.arg_name,
929                    // When writing SubgroupId, we always guarantee that local_invocation_index_name is written
930                    ep_input.local_invocation_index_name.as_ref().unwrap()
931                )?;
932            }
933            Some(crate::Binding::Location {
934                interpolation: Some(crate::Interpolation::PerVertex),
935                ..
936            }) => {
937                if self.options.shader_model < ShaderModel::V6_1 {
938                    return Err(Error::ShaderModelTooLow(
939                        "per_vertex fragment inputs".to_string(),
940                        ShaderModel::V6_1,
941                    ));
942                }
943                write!(
944                    self.out,
945                    "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946                    ep_input.arg_name,
947                    fake_member.name,
948                )?;
949            }
950            _ => {
951                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952            }
953        }
954        Ok(())
955    }
956
957    /// Write an entry point preface that initializes the arguments as specified in IR.
958    fn write_ep_arguments_initialization(
959        &mut self,
960        module: &Module,
961        func: &crate::Function,
962        ep_index: u16,
963    ) -> BackendResult {
964        let ep = &module.entry_points[ep_index as usize];
965        let ep_input = match self
966            .entry_point_io
967            .get_mut(&(ep_index as usize))
968            .unwrap()
969            .input
970            .take()
971        {
972            Some(ep_input) => ep_input,
973            None => return Ok(()),
974        };
975        let mut fake_iter = ep_input.members.iter();
976        for (arg_index, arg) in func.arguments.iter().enumerate() {
977            write!(self.out, "{}", back::INDENT)?;
978            self.write_type(module, arg.ty)?;
979            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980            write!(self.out, " {arg_name}")?;
981            match module.types[arg.ty].inner {
982                TypeInner::Array { base, size, .. } => {
983                    self.write_array_size(module, base, size)?;
984                    write!(self.out, " = ")?;
985                    self.write_ep_argument_initialization(
986                        ep,
987                        &ep_input,
988                        fake_iter.next().unwrap(),
989                    )?;
990                    writeln!(self.out, ";")?;
991                }
992                TypeInner::Struct { ref members, .. } => {
993                    write!(self.out, " = {{ ")?;
994                    for index in 0..members.len() {
995                        if index != 0 {
996                            write!(self.out, ", ")?;
997                        }
998                        self.write_ep_argument_initialization(
999                            ep,
1000                            &ep_input,
1001                            fake_iter.next().unwrap(),
1002                        )?;
1003                    }
1004                    writeln!(self.out, " }};")?;
1005                }
1006                _ => {
1007                    write!(self.out, " = ")?;
1008                    self.write_ep_argument_initialization(
1009                        ep,
1010                        &ep_input,
1011                        fake_iter.next().unwrap(),
1012                    )?;
1013                    writeln!(self.out, ";")?;
1014                }
1015            }
1016        }
1017        assert!(fake_iter.next().is_none());
1018        Ok(())
1019    }
1020
1021    /// Helper method used to write global variables
1022    /// # Notes
1023    /// Always adds a newline
1024    fn write_global(
1025        &mut self,
1026        module: &Module,
1027        handle: Handle<crate::GlobalVariable>,
1028    ) -> BackendResult {
1029        let global = &module.global_variables[handle];
1030        let inner = &module.types[global.ty].inner;
1031
1032        let handle_ty = match *inner {
1033            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
1034            _ => inner,
1035        };
1036
1037        // External textures are handled entirely differently, so defer entirely to that method.
1038        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
1039        // their bindings separately.
1040        let is_external_texture = matches!(
1041            *handle_ty,
1042            TypeInner::Image {
1043                class: crate::ImageClass::External,
1044                ..
1045            }
1046        );
1047        if is_external_texture {
1048            return self.write_global_external_texture(module, handle, global);
1049        }
1050
1051        if let Some(ref binding) = global.binding {
1052            if let Err(err) = self.options.resolve_resource_binding(binding) {
1053                log::debug!(
1054                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1055                    handle,
1056                    global.name,
1057                    err,
1058                );
1059                return Ok(());
1060            }
1061        }
1062
1063        // Samplers are handled entirely differently, so defer entirely to that method.
1064        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
1065
1066        if is_sampler {
1067            return self.write_global_sampler(module, handle, global);
1068        }
1069
1070        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
1071        let register_ty = match global.space {
1072            crate::AddressSpace::Function => unreachable!("Function address space"),
1073            crate::AddressSpace::Private => {
1074                write!(self.out, "static ")?;
1075                self.write_type(module, global.ty)?;
1076                ""
1077            }
1078            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
1079                write!(self.out, "groupshared ")?;
1080                self.write_type(module, global.ty)?;
1081                ""
1082            }
1083            crate::AddressSpace::Uniform => {
1084                // constant buffer declarations are expected to be inlined, e.g.
1085                // `cbuffer foo: register(b0) { field1: type1; }`
1086                write!(self.out, "cbuffer")?;
1087                "b"
1088            }
1089            crate::AddressSpace::Storage { access } => {
1090                if global
1091                    .memory_decorations
1092                    .contains(crate::MemoryDecorations::COHERENT)
1093                {
1094                    write!(self.out, "globallycoherent ")?;
1095                }
1096                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
1097                    ("RW", "u")
1098                } else {
1099                    ("", "t")
1100                };
1101                write!(self.out, "{prefix}ByteAddressBuffer")?;
1102                register
1103            }
1104            crate::AddressSpace::Handle => {
1105                let register = match *handle_ty {
1106                    // all storage textures are UAV, unconditionally
1107                    TypeInner::Image {
1108                        class: crate::ImageClass::Storage { .. },
1109                        ..
1110                    } => "u",
1111                    _ => "t",
1112                };
1113                self.write_type(module, global.ty)?;
1114                register
1115            }
1116            crate::AddressSpace::Immediate => {
1117                // The type of the immediates will be wrapped in `ConstantBuffer`
1118                write!(self.out, "ConstantBuffer<")?;
1119                "b"
1120            }
1121            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
1122                unimplemented!()
1123            }
1124        };
1125
1126        // If the global is a immediate data write the type now because it will be a
1127        // generic argument to `ConstantBuffer`
1128        if global.space == crate::AddressSpace::Immediate {
1129            self.write_global_type(module, global.ty)?;
1130
1131            // need to write the array size if the type was emitted with `write_type`
1132            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1133                self.write_array_size(module, base, size)?;
1134            }
1135
1136            // Close the angled brackets for the generic argument
1137            write!(self.out, ">")?;
1138        }
1139
1140        let name = &self.names[&NameKey::GlobalVariable(handle)];
1141        write!(self.out, " {name}")?;
1142
1143        // Immediates need to be assigned a binding explicitly by the consumer
1144        // since naga has no way to know the binding from the shader alone
1145        if global.space == crate::AddressSpace::Immediate {
1146            match module.types[global.ty].inner {
1147                TypeInner::Struct { .. } => {}
1148                _ => {
1149                    return Err(Error::Unimplemented(format!(
1150                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1151                    )));
1152                }
1153            }
1154
1155            let target = self
1156                .options
1157                .immediates_target
1158                .as_ref()
1159                .expect("No bind target was defined for the immediates block");
1160            write!(self.out, ": register(b{}", target.register)?;
1161            if target.space != 0 {
1162                write!(self.out, ", space{}", target.space)?;
1163            }
1164            write!(self.out, ")")?;
1165        }
1166
1167        if let Some(ref binding) = global.binding {
1168            // this was already resolved earlier when we started evaluating an entry point.
1169            let bt = self.options.resolve_resource_binding(binding).unwrap();
1170
1171            // need to write the binding array size if the type was emitted with `write_type`
1172            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1173                if let Some(overridden_size) = bt.binding_array_size {
1174                    write!(self.out, "[{overridden_size}]")?;
1175                } else {
1176                    self.write_array_size(module, base, size)?;
1177                }
1178            }
1179
1180            write!(self.out, " : register({}{}", register_ty, bt.register)?;
1181            if bt.space != 0 {
1182                write!(self.out, ", space{}", bt.space)?;
1183            }
1184            write!(self.out, ")")?;
1185        } else {
1186            // need to write the array size if the type was emitted with `write_type`
1187            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1188                self.write_array_size(module, base, size)?;
1189            }
1190            if global.space == crate::AddressSpace::Private {
1191                write!(self.out, " = ")?;
1192                if let Some(init) = global.init {
1193                    self.write_const_expression(module, init, &module.global_expressions)?;
1194                } else {
1195                    self.write_default_init(module, global.ty)?;
1196                }
1197            }
1198        }
1199
1200        if global.space == crate::AddressSpace::Uniform {
1201            write!(self.out, " {{ ")?;
1202
1203            self.write_global_type(module, global.ty)?;
1204
1205            write!(
1206                self.out,
1207                " {}",
1208                &self.names[&NameKey::GlobalVariable(handle)]
1209            )?;
1210
1211            // need to write the array size if the type was emitted with `write_type`
1212            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1213                self.write_array_size(module, base, size)?;
1214            }
1215
1216            writeln!(self.out, "; }}")?;
1217        } else {
1218            writeln!(self.out, ";")?;
1219        }
1220
1221        Ok(())
1222    }
1223
1224    fn write_global_sampler(
1225        &mut self,
1226        module: &Module,
1227        handle: Handle<crate::GlobalVariable>,
1228        global: &crate::GlobalVariable,
1229    ) -> BackendResult {
1230        let binding = *global.binding.as_ref().unwrap();
1231
1232        let key = super::SamplerIndexBufferKey {
1233            group: binding.group,
1234        };
1235        self.write_wrapped_sampler_buffer(key)?;
1236
1237        // This was already validated, so we can confidently unwrap it.
1238        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240        match module.types[global.ty].inner {
1241            TypeInner::Sampler { comparison } => {
1242                // If we are generating a static access, we create a variable for the sampler.
1243                //
1244                // This prevents the DXIL from containing multiple lookups for the sampler, which
1245                // the backend compiler will then have to eliminate. AMD does seem to be able to
1246                // eliminate these, but better safe than sorry.
1247
1248                write!(self.out, "static const ")?;
1249                self.write_type(module, global.ty)?;
1250
1251                let heap_var = if comparison {
1252                    COMPARISON_SAMPLER_HEAP_VAR
1253                } else {
1254                    SAMPLER_HEAP_VAR
1255                };
1256
1257                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258                let name = &self.names[&NameKey::GlobalVariable(handle)];
1259                writeln!(
1260                    self.out,
1261                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262                    register = bt.register
1263                )?;
1264            }
1265            TypeInner::BindingArray { .. } => {
1266                // If we are generating a binding array, we cannot directly access the sampler as the index
1267                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1268                // that represents the "base" index into the sampler index buffer. This constant is added
1269                // to the user provided index to get the final index into the sampler index buffer.
1270
1271                let name = &self.names[&NameKey::GlobalVariable(handle)];
1272                writeln!(
1273                    self.out,
1274                    "static const uint {name} = {register};",
1275                    register = bt.register
1276                )?;
1277            }
1278            _ => unreachable!(),
1279        };
1280
1281        Ok(())
1282    }
1283
1284    /// Write the declarations for an external texture global variable.
1285    /// These are emitted as multiple global variables: Three `Texture2D`s
1286    /// (one for each plane) and a parameters cbuffer.
1287    fn write_global_external_texture(
1288        &mut self,
1289        module: &Module,
1290        handle: Handle<crate::GlobalVariable>,
1291        global: &crate::GlobalVariable,
1292    ) -> BackendResult {
1293        let res_binding = global
1294            .binding
1295            .as_ref()
1296            .expect("External texture global variables must have a resource binding");
1297        let ext_tex_bindings = match self
1298            .options
1299            .resolve_external_texture_resource_binding(res_binding)
1300        {
1301            Ok(bindings) => bindings,
1302            Err(err) => {
1303                log::debug!(
1304                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305                    handle,
1306                    global.name,
1307                    err,
1308                );
1309                return Ok(());
1310            }
1311        };
1312
1313        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314            write!(
1315                self.out,
1316                "Texture2D<float4> {}: register(t{}",
1317                name, bt.register
1318            )?;
1319            if bt.space != 0 {
1320                write!(self.out, ", space{}", bt.space)?;
1321            }
1322            writeln!(self.out, ");")?;
1323            Ok(())
1324        };
1325        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326            let plane_name = &self.names
1327                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328            write_plane(bt, plane_name)?;
1329        }
1330
1331        let params_name = &self.names
1332            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333        let params_ty_name =
1334            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335        write!(
1336            self.out,
1337            "cbuffer {}: register(b{}",
1338            params_name, ext_tex_bindings.params.register
1339        )?;
1340        if ext_tex_bindings.params.space != 0 {
1341            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342        }
1343        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345        Ok(())
1346    }
1347
1348    /// Helper method used to write global constants
1349    ///
1350    /// # Notes
1351    /// Ends in a newline
1352    fn write_global_constant(
1353        &mut self,
1354        module: &Module,
1355        handle: Handle<crate::Constant>,
1356    ) -> BackendResult {
1357        write!(self.out, "static const ")?;
1358        let constant = &module.constants[handle];
1359        self.write_type(module, constant.ty)?;
1360        let name = &self.names[&NameKey::Constant(handle)];
1361        write!(self.out, " {name}")?;
1362        // Write size for array type
1363        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364            self.write_array_size(module, base, size)?;
1365        }
1366        write!(self.out, " = ")?;
1367        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368        writeln!(self.out, ";")?;
1369        Ok(())
1370    }
1371
1372    pub(super) fn write_array_size(
1373        &mut self,
1374        module: &Module,
1375        base: Handle<crate::Type>,
1376        size: crate::ArraySize,
1377    ) -> BackendResult {
1378        write!(self.out, "[")?;
1379
1380        match size.resolve(module.to_ctx())? {
1381            proc::IndexableLength::Known(size) => {
1382                write!(self.out, "{size}")?;
1383            }
1384            proc::IndexableLength::Dynamic => unreachable!(),
1385        }
1386
1387        write!(self.out, "]")?;
1388
1389        if let TypeInner::Array {
1390            base: next_base,
1391            size: next_size,
1392            ..
1393        } = module.types[base].inner
1394        {
1395            self.write_array_size(module, next_base, next_size)?;
1396        }
1397
1398        Ok(())
1399    }
1400
1401    /// Helper method used to write structs
1402    ///
1403    /// # Notes
1404    /// Ends in a newline
1405    fn write_struct(
1406        &mut self,
1407        module: &Module,
1408        handle: Handle<crate::Type>,
1409        members: &[crate::StructMember],
1410        span: u32,
1411        shader_stage: Option<(ShaderStage, Io)>,
1412    ) -> BackendResult {
1413        // Write struct name
1414        let struct_name = &self.names[&NameKey::Type(handle)];
1415        writeln!(self.out, "struct {struct_name} {{")?;
1416
1417        let mut last_offset = 0;
1418        for (index, member) in members.iter().enumerate() {
1419            if member.binding.is_none() && member.offset > last_offset {
1420                // using int as padding should work as long as the backend
1421                // doesn't support a type that's less than 4 bytes in size
1422                // (Error::UnsupportedScalar catches this)
1423                let padding = (member.offset - last_offset) / 4;
1424                for i in 0..padding {
1425                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1426                }
1427            }
1428            let ty_inner = &module.types[member.ty].inner;
1429            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1430
1431            // The indentation is only for readability
1432            write!(self.out, "{}", back::INDENT)?;
1433
1434            match module.types[member.ty].inner {
1435                TypeInner::Array { base, size, .. } => {
1436                    // HLSL arrays are written as `type name[size]`
1437
1438                    self.write_global_type(module, member.ty)?;
1439
1440                    // Write `name`
1441                    write!(
1442                        self.out,
1443                        " {}",
1444                        &self.names[&NameKey::StructMember(handle, index as u32)]
1445                    )?;
1446                    // Write [size]
1447                    self.write_array_size(module, base, size)?;
1448                }
1449                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1450                // See the module-level block comment in mod.rs for details.
1451                TypeInner::Matrix {
1452                    rows,
1453                    columns,
1454                    scalar,
1455                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1456                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1457                    let field_name_key = NameKey::StructMember(handle, index as u32);
1458
1459                    for i in 0..columns as u8 {
1460                        if i != 0 {
1461                            write!(self.out, "; ")?;
1462                        }
1463                        self.write_value_type(module, &vec_ty)?;
1464                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1465                    }
1466                }
1467                _ => {
1468                    // Write modifier before type
1469                    if let Some(ref binding) = member.binding {
1470                        self.write_modifier(binding)?;
1471                    }
1472
1473                    // Even though Naga IR matrices are column-major, we must describe
1474                    // matrices passed from the CPU as being in row-major order.
1475                    // See the module-level block comment in mod.rs for details.
1476                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1477                        write!(self.out, "row_major ")?;
1478                    }
1479
1480                    // Write the member type and name
1481                    self.write_type(module, member.ty)?;
1482                    write!(
1483                        self.out,
1484                        " {}",
1485                        &self.names[&NameKey::StructMember(handle, index as u32)]
1486                    )?;
1487                }
1488            }
1489
1490            self.write_semantic(&member.binding, shader_stage)?;
1491            writeln!(self.out, ";")?;
1492        }
1493
1494        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1495        if members.last().unwrap().binding.is_none() && span > last_offset {
1496            let padding = (span - last_offset) / 4;
1497            for i in 0..padding {
1498                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1499            }
1500        }
1501
1502        writeln!(self.out, "}};")?;
1503        Ok(())
1504    }
1505
1506    /// Helper method used to write global/structs non image/sampler types
1507    ///
1508    /// # Notes
1509    /// Adds no trailing or leading whitespace
1510    pub(super) fn write_global_type(
1511        &mut self,
1512        module: &Module,
1513        ty: Handle<crate::Type>,
1514    ) -> BackendResult {
1515        let matrix_data = get_inner_matrix_data(module, ty);
1516
1517        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1518        // See the module-level block comment in mod.rs for details.
1519        if let Some(MatrixType {
1520            columns,
1521            rows: crate::VectorSize::Bi,
1522            width,
1523        }) = matrix_data
1524        {
1525            write!(self.out, "__mat{}x2_f{}", columns as u8, width * 8)?;
1526        } else {
1527            // Even though Naga IR matrices are column-major, we must describe
1528            // matrices passed from the CPU as being in row-major order.
1529            // See the module-level block comment in mod.rs for details.
1530            if matrix_data.is_some() {
1531                write!(self.out, "row_major ")?;
1532            }
1533
1534            self.write_type(module, ty)?;
1535        }
1536
1537        Ok(())
1538    }
1539
1540    /// Helper method used to write non image/sampler types
1541    ///
1542    /// # Notes
1543    /// Adds no trailing or leading whitespace
1544    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545        let inner = &module.types[ty].inner;
1546        match *inner {
1547            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548            // hlsl array has the size separated from the base type
1549            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550                self.write_type(module, base)?
1551            }
1552            ref other => self.write_value_type(module, other)?,
1553        }
1554
1555        Ok(())
1556    }
1557
1558    /// Helper method used to write value types
1559    ///
1560    /// # Notes
1561    /// Adds no trailing or leading whitespace
1562    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563        match *inner {
1564            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566            }
1567            TypeInner::Vector { size, scalar } => {
1568                write!(
1569                    self.out,
1570                    "{}{}",
1571                    scalar.to_hlsl_str()?,
1572                    common::vector_size_str(size)
1573                )?;
1574            }
1575            TypeInner::Matrix {
1576                columns,
1577                rows,
1578                scalar,
1579            } => {
1580                // The IR supports only float matrix
1581                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1582
1583                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1584                write!(
1585                    self.out,
1586                    "{}{}x{}",
1587                    scalar.to_hlsl_str()?,
1588                    common::vector_size_str(columns),
1589                    common::vector_size_str(rows),
1590                )?;
1591            }
1592            TypeInner::Image {
1593                dim,
1594                arrayed,
1595                class,
1596            } => {
1597                self.write_image_type(dim, arrayed, class)?;
1598            }
1599            TypeInner::Sampler { comparison } => {
1600                let sampler = if comparison {
1601                    "SamplerComparisonState"
1602                } else {
1603                    "SamplerState"
1604                };
1605                write!(self.out, "{sampler}")?;
1606            }
1607            // HLSL arrays are written as `type name[size]`
1608            // Current code is written arrays only as `[size]`
1609            // Base `type` and `name` should be written outside
1610            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611                self.write_array_size(module, base, size)?;
1612            }
1613            TypeInner::AccelerationStructure { .. } => {
1614                write!(self.out, "RaytracingAccelerationStructure")?;
1615            }
1616            TypeInner::RayQuery { .. } => {
1617                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1618                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619            }
1620            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621        }
1622
1623        Ok(())
1624    }
1625
1626    /// Helper method used to write functions
1627    /// # Notes
1628    /// Ends in a newline
1629    fn write_function(
1630        &mut self,
1631        module: &Module,
1632        name: &str,
1633        func: &crate::Function,
1634        func_ctx: &back::FunctionCtx<'_>,
1635        info: &valid::FunctionInfo,
1636        header: String,
1637    ) -> BackendResult {
1638        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1639
1640        self.update_expressions_to_bake(module, func, info);
1641        let ep = match func_ctx.ty {
1642            back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1643            back::FunctionType::Function(_) => None,
1644        };
1645
1646        let nested = matches!(
1647            ep,
1648            Some(crate::EntryPoint {
1649                stage: ShaderStage::Task | ShaderStage::Mesh,
1650                ..
1651            })
1652        );
1653        if !nested {
1654            write!(self.out, "{header}")?;
1655        }
1656
1657        if let Some(ref result) = func.result {
1658            // Write typedef if return type is an array
1659            let array_return_type = match module.types[result.ty].inner {
1660                TypeInner::Array { base, size, .. } => {
1661                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1662                    write!(self.out, "typedef ")?;
1663                    self.write_type(module, result.ty)?;
1664                    write!(self.out, " {array_return_type}")?;
1665                    self.write_array_size(module, base, size)?;
1666                    writeln!(self.out, ";")?;
1667                    Some(array_return_type)
1668                }
1669                _ => None,
1670            };
1671
1672            // Write modifier
1673            if let Some(
1674                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1675            ) = result.binding
1676            {
1677                self.write_modifier(binding)?;
1678            }
1679
1680            // Write return type
1681            match func_ctx.ty {
1682                back::FunctionType::Function(_) => {
1683                    if let Some(array_return_type) = array_return_type {
1684                        write!(self.out, "{array_return_type}")?;
1685                    } else {
1686                        self.write_type(module, result.ty)?;
1687                    }
1688                }
1689                back::FunctionType::EntryPoint(index) => {
1690                    if let Some(ref ep_output) =
1691                        self.entry_point_io.get(&(index as usize)).unwrap().output
1692                    {
1693                        write!(self.out, "{}", ep_output.ty_name)?;
1694                    } else {
1695                        self.write_type(module, result.ty)?;
1696                    }
1697                }
1698            }
1699        } else {
1700            write!(self.out, "void")?;
1701        }
1702
1703        let nested_name = if nested {
1704            self.namer.call(&format!("_{name}"))
1705        } else {
1706            name.to_string()
1707        };
1708
1709        // Write function name
1710        write!(self.out, " {nested_name}(")?;
1711
1712        let need_workgroup_variables_initialization =
1713            self.need_workgroup_variables_initialization(func_ctx, module);
1714
1715        let mut any_args_written = false;
1716        let mut separator = || {
1717            if any_args_written {
1718                ", "
1719            } else {
1720                any_args_written = true;
1721                ""
1722            }
1723        };
1724
1725        let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1726        let mut local_invocation_index_name = None;
1727        // For nested entry points, collect arg names as we write them so that
1728        // write_nested_function_outer can pass the exact same names to the call site.
1729        let mut nested_wgsl_args: Vec<String> = Vec::new();
1730        let mut nested_task_payload_name: Option<String> = None;
1731        // Write function arguments for non entry point functions
1732        match func_ctx.ty {
1733            back::FunctionType::Function(handle) => {
1734                for (index, arg) in func.arguments.iter().enumerate() {
1735                    write!(self.out, "{}", separator())?;
1736                    self.write_function_argument(module, handle, arg, index)?;
1737                }
1738                // If this reads a task payload variable the variable needs to be passed as an `in` argument
1739                for (var_handle, var) in module.global_variables.iter() {
1740                    let uses = info[var_handle];
1741                    if uses.contains(valid::GlobalUse::READ)
1742                        && !uses.contains(valid::GlobalUse::WRITE)
1743                        && var.space == crate::AddressSpace::TaskPayload
1744                    {
1745                        self.function_task_payload_var.insert(handle, var_handle);
1746                        write!(self.out, "{}in ", separator())?;
1747
1748                        self.write_type(module, var.ty)?;
1749                        let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1750                        write!(self.out, " {name}")?;
1751                        break;
1752                    }
1753                }
1754            }
1755            back::FunctionType::EntryPoint(ep_index) => {
1756                let ep = &module.entry_points[ep_index as usize];
1757                if let Some(ref ep_input) =
1758                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1759                {
1760                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1761                    separator();
1762                    nested_wgsl_args.push(ep_input.arg_name.clone());
1763                } else {
1764                    let stage = ep.stage;
1765                    for (index, arg) in func.arguments.iter().enumerate() {
1766                        write!(self.out, "{}", separator())?;
1767                        self.write_type(module, arg.ty)?;
1768
1769                        let argument_name =
1770                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1771
1772                        if arg.binding
1773                            == Some(crate::Binding::BuiltIn(
1774                                crate::BuiltIn::LocalInvocationIndex,
1775                            ))
1776                        {
1777                            local_invocation_index_name = Some(argument_name.clone());
1778                        }
1779
1780                        nested_wgsl_args.push(argument_name.clone());
1781                        write!(self.out, " {argument_name}")?;
1782                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1783                            self.write_array_size(module, base, size)?;
1784                        }
1785
1786                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1787                    }
1788                }
1789                if ep.stage == ShaderStage::Mesh {
1790                    if let Some(var_handle) = ep.task_payload {
1791                        let var = &module.global_variables[var_handle];
1792                        write!(self.out, "{}in ", separator())?;
1793                        self.write_type(module, var.ty)?;
1794                        let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1795                        write!(self.out, " {arg_name}")?;
1796                        nested_task_payload_name = Some(arg_name.clone());
1797                        if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1798                            self.write_array_size(module, base, size)?;
1799                        }
1800                    }
1801                }
1802                if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1803                    let name = self.namer.call("local_invocation_index");
1804                    write!(self.out, "{}uint {name}", separator())?;
1805                    write!(self.out, " : SV_GroupIndex")?;
1806                    local_invocation_index_name = Some(name);
1807                }
1808            }
1809        }
1810        // Ends of arguments
1811        write!(self.out, ")")?;
1812
1813        // Write semantic if it present
1814        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1815            let stage = module.entry_points[index as usize].stage;
1816            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1817                self.write_semantic(binding, Some((stage, Io::Output)))?;
1818            }
1819        }
1820
1821        // Function body start
1822        writeln!(self.out)?;
1823        writeln!(self.out, "{{")?;
1824
1825        if need_workgroup_variables_initialization && !nested {
1826            let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1827                unreachable!();
1828            };
1829            writeln!(
1830                self.out,
1831                "{}if ({} == 0) {{",
1832                back::INDENT,
1833                // need_workgroup_variables_initialization forces this to be written
1834                // if the user doesn't specify it (so this must be Some())
1835                local_invocation_index_name.as_ref().unwrap(),
1836            )?;
1837            self.write_workgroup_variables_initialization(
1838                func_ctx,
1839                module,
1840                module.entry_points[index as usize].stage,
1841            )?;
1842
1843            writeln!(self.out, "{}}}", back::INDENT)?;
1844            self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1845        }
1846
1847        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1848            self.write_ep_arguments_initialization(module, func, index)?;
1849        }
1850
1851        // Write function local variables
1852        for (handle, local) in func.local_variables.iter() {
1853            // Write indentation (only for readability)
1854            write!(self.out, "{}", back::INDENT)?;
1855
1856            // Write the local name
1857            // The leading space is important
1858            self.write_type(module, local.ty)?;
1859            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1860            // Write size for array type
1861            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1862                self.write_array_size(module, base, size)?;
1863            }
1864
1865            let is_ray_query = match module.types[local.ty].inner {
1866                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1867                TypeInner::RayQuery { .. } => true,
1868                _ => {
1869                    write!(self.out, " = ")?;
1870                    // Write the local initializer if needed
1871                    if let Some(init) = local.init {
1872                        self.write_expr(module, init, func_ctx)?;
1873                    } else {
1874                        // Zero initialize local variables
1875                        self.write_default_init(module, local.ty)?;
1876                    }
1877                    false
1878                }
1879            };
1880            // Finish the local with `;` and add a newline (only for readability)
1881            writeln!(self.out, ";")?;
1882            // If it's a ray query, we also want a tracker variable
1883            if is_ray_query {
1884                write!(self.out, "{}", back::INDENT)?;
1885                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1886                writeln!(
1887                    self.out,
1888                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1889                    self.names[&func_ctx.name_key(handle)]
1890                )?;
1891            }
1892        }
1893
1894        if !func.local_variables.is_empty() {
1895            writeln!(self.out)?;
1896        }
1897
1898        // Write the function body (statement list)
1899        for sta in func.body.iter() {
1900            // The indentation should always be 1 when writing the function body
1901            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1902        }
1903
1904        writeln!(self.out, "}}")?;
1905
1906        if nested {
1907            self.write_nested_function_outer(
1908                module,
1909                func_ctx,
1910                &header,
1911                name,
1912                need_workgroup_variables_initialization,
1913                &nested_name,
1914                ep.unwrap(),
1915                NestedEntryPointArgs {
1916                    user_args: nested_wgsl_args,
1917                    task_payload: nested_task_payload_name,
1918                    // guaranteed to be set for nested functions (task/mesh shaders)
1919                    local_invocation_index: local_invocation_index_name.unwrap(),
1920                },
1921            )?;
1922        }
1923
1924        self.named_expressions.clear();
1925
1926        Ok(())
1927    }
1928
1929    fn write_function_argument(
1930        &mut self,
1931        module: &Module,
1932        handle: Handle<crate::Function>,
1933        arg: &crate::FunctionArgument,
1934        index: usize,
1935    ) -> BackendResult {
1936        // External texture arguments must be expanded into separate
1937        // arguments for each plane and the params buffer.
1938        if let TypeInner::Image {
1939            class: crate::ImageClass::External,
1940            ..
1941        } = module.types[arg.ty].inner
1942        {
1943            return self.write_function_external_texture_argument(module, handle, index);
1944        }
1945
1946        // Write argument type
1947        let arg_ty = match module.types[arg.ty].inner {
1948            // pointers in function arguments are expected and resolve to `inout`
1949            TypeInner::Pointer { base, .. } => {
1950                //TODO: can we narrow this down to just `in` when possible?
1951                write!(self.out, "inout ")?;
1952                base
1953            }
1954            _ => arg.ty,
1955        };
1956        self.write_type(module, arg_ty)?;
1957
1958        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960        // Write argument name. Space is important.
1961        write!(self.out, " {argument_name}")?;
1962        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963            self.write_array_size(module, base, size)?;
1964        }
1965
1966        Ok(())
1967    }
1968
1969    fn write_function_external_texture_argument(
1970        &mut self,
1971        module: &Module,
1972        handle: Handle<crate::Function>,
1973        index: usize,
1974    ) -> BackendResult {
1975        let plane_names = [0, 1, 2].map(|i| {
1976            &self.names[&NameKey::ExternalTextureFunctionArgument(
1977                handle,
1978                index as u32,
1979                ExternalTextureNameKey::Plane(i),
1980            )]
1981        });
1982        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983            handle,
1984            index as u32,
1985            ExternalTextureNameKey::Params,
1986        )];
1987        let params_ty_name =
1988            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989        write!(
1990            self.out,
1991            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992            plane_names[0], plane_names[1], plane_names[2],
1993        )?;
1994        Ok(())
1995    }
1996
1997    fn need_workgroup_variables_initialization(
1998        &mut self,
1999        func_ctx: &back::FunctionCtx,
2000        module: &Module,
2001    ) -> bool {
2002        self.options.zero_initialize_workgroup_memory
2003            && func_ctx.ty.is_compute_like_entry_point(module)
2004            && module.global_variables.iter().any(|(handle, var)| {
2005                !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006            })
2007    }
2008
2009    pub(super) fn write_workgroup_variables_initialization(
2010        &mut self,
2011        func_ctx: &back::FunctionCtx,
2012        module: &Module,
2013        stage: ShaderStage,
2014    ) -> BackendResult {
2015        let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016            // Read-only in mesh shaders
2017            let task_needs_zero =
2018                (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019            !func_ctx.info[handle].is_empty()
2020                && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021        });
2022
2023        for (handle, var) in vars {
2024            let name = &self.names[&NameKey::GlobalVariable(handle)];
2025            write!(self.out, "{}{} = ", back::Level(2), name)?;
2026            self.write_default_init(module, var.ty)?;
2027            writeln!(self.out, ";")?;
2028        }
2029        Ok(())
2030    }
2031
    /// Helper method used to write switches
    ///
    /// Writes a switch over `selector` with the given `cases`, starting at
    /// indentation `level`. One of two output shapes is chosen:
    ///
    /// - If every case except the last is an empty fallthrough case, only the
    ///   last case's body is written, inside `do { ... } while(false);`. FXC
    ///   does not handle such switches correctly, and the loop still gives
    ///   `break` statements in the body somewhere to exit to.
    /// - Otherwise a real `switch` is written. Each case body goes in its own
    ///   braces. Non-empty fallthrough bodies are followed by copies of the
    ///   bodies they would fall into, because FXC rejects fallthrough from a
    ///   non-empty case.
    ///
    /// `continue` statements inside the switch are forwarded through
    /// `self.continue_ctx` (see `back::continue_forward`). That may declare a
    /// flag variable before the switch and emit an `if (flag) { continue; }` /
    /// `break;` check after it.
    fn write_switch(
        &mut self,
        module: &Module,
        func_ctx: &back::FunctionCtx<'_>,
        level: back::Level,
        selector: Handle<crate::Expression>,
        cases: &[crate::SwitchCase],
    ) -> BackendResult {
        // Write all cases
        let indent_level_1 = level.next();
        let indent_level_2 = indent_level_1.next();

        // See docs of `back::continue_forward` module.
        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
            writeln!(self.out, "{level}bool {variable} = false;",)?;
        };

        // Check if there is only one body, by seeing if all except the last case are fall through
        // with empty bodies. FXC doesn't handle these switches correctly, so
        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
        // is no need to check if one of the cases would have matched.
        let one_body = cases
            .iter()
            .rev()
            .skip(1)
            .all(|case| case.fall_through && case.body.is_empty());
        if one_body {
            // Start the do-while
            writeln!(self.out, "{level}do {{")?;
            // Note: Expressions have no side-effects so we don't need to emit selector expression.

            // Body
            if let Some(case) = cases.last() {
                for sta in case.body.iter() {
                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
                }
            }
            // End do-while
            writeln!(self.out, "{level}}} while(false);")?;
        } else {
            // Start the switch
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            for (i, case) in cases.iter().enumerate() {
                // Case label; unsigned values get a `u` suffix.
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
                }

                // The new block is not only stylistic, it plays a role here:
                // We might end up having to write the same case body
                // multiple times due to FXC not supporting fallthrough.
                // Therefore, some `Expression`s written by `Statement::Emit`
                // will end up having the same name (`_expr<handle_index>`).
                // So we need to put each case in its own scope.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                // Although FXC does support a series of case clauses before
                // a block[^yes], it does not support fallthrough from a
                // non-empty case block to the next[^no]. If this case has a
                // non-empty body with a fallthrough, emulate that by
                // duplicating the bodies of all the cases it would fall
                // into as extensions of this case's own body. This makes
                // the HLSL output potentially quadratic in the size of the
                // Naga IR.
                //
                // [^yes]: ```hlsl
                // case 1:
                // case 2: do_stuff()
                // ```
                // [^no]: ```hlsl
                // case 1: do_this();
                // case 2: do_that();
                // ```
                if case.fall_through && !case.body.is_empty() {
                    // Index of the first case at or after `i + 1` that does not
                    // fall through. The `unwrap` relies on some later case
                    // ending the fallthrough chain (presumably guaranteed by
                    // validation; TODO confirm).
                    let curr_len = i + 1;
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Clear all named expressions that were previously inserted by the statements in the block
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    // Only add `break;` when the final copied body does not already
                    // end in a terminator statement.
                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?;
        }

        // Handle any forwarded continue statements.
        use back::continue_forward::ExitControlFlow;
        let op = match self.continue_ctx.exit_switch() {
            ExitControlFlow::None => None,
            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
            ExitControlFlow::Break { variable } => Some(("break", variable)),
        };
        if let Some((control_flow, variable)) = op {
            writeln!(self.out, "{level}if ({variable}) {{")?;
            writeln!(self.out, "{indent_level_1}{control_flow};")?;
            writeln!(self.out, "{level}}}")?;
        }

        Ok(())
    }
2176
2177    fn write_index(
2178        &mut self,
2179        module: &Module,
2180        index: Index,
2181        func_ctx: &back::FunctionCtx<'_>,
2182    ) -> BackendResult {
2183        match index {
2184            Index::Static(index) => {
2185                write!(self.out, "{index}")?;
2186            }
2187            Index::Expression(index) => {
2188                self.write_expr(module, index, func_ctx)?;
2189            }
2190        }
2191        Ok(())
2192    }
2193
2194    /// Helper method used to write statements
2195    ///
2196    /// # Notes
2197    /// Always adds a newline
2198    fn write_stmt(
2199        &mut self,
2200        module: &Module,
2201        stmt: &crate::Statement,
2202        func_ctx: &back::FunctionCtx<'_>,
2203        level: back::Level,
2204    ) -> BackendResult {
2205        use crate::Statement;
2206
2207        match *stmt {
2208            Statement::Emit(ref range) => {
2209                for handle in range.clone() {
2210                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211                    let expr_name = if ptr_class.is_some() {
2212                        // HLSL can't save a pointer-valued expression in a variable,
2213                        // but we shouldn't ever need to: they should never be named expressions,
2214                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
2215                        None
2216                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217                        // Front end provides names for all variables at the start of writing.
2218                        // But we write them to step by step. We need to recache them
2219                        // Otherwise, we could accidentally write variable name instead of full expression.
2220                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
2221                        Some(self.namer.call(name))
2222                    } else if self.need_bake_expressions.contains(&handle) {
2223                        Some(Baked(handle).to_string())
2224                    } else {
2225                        None
2226                    };
2227
2228                    if let Some(name) = expr_name {
2229                        write!(self.out, "{level}")?;
2230                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231                    }
2232                }
2233            }
2234            // TODO: copy-paste from glsl-out
2235            Statement::Block(ref block) => {
2236                write!(self.out, "{level}")?;
2237                writeln!(self.out, "{{")?;
2238                for sta in block.iter() {
2239                    // Increase the indentation to help with readability
2240                    self.write_stmt(module, sta, func_ctx, level.next())?
2241                }
2242                writeln!(self.out, "{level}}}")?
2243            }
2244            // TODO: copy-paste from glsl-out
2245            Statement::If {
2246                condition,
2247                ref accept,
2248                ref reject,
2249            } => {
2250                write!(self.out, "{level}")?;
2251                write!(self.out, "if (")?;
2252                self.write_expr(module, condition, func_ctx)?;
2253                writeln!(self.out, ") {{")?;
2254
2255                let l2 = level.next();
2256                for sta in accept {
2257                    // Increase indentation to help with readability
2258                    self.write_stmt(module, sta, func_ctx, l2)?;
2259                }
2260
2261                // If there are no statements in the reject block we skip writing it
2262                // This is only for readability
2263                if !reject.is_empty() {
2264                    writeln!(self.out, "{level}}} else {{")?;
2265
2266                    for sta in reject {
2267                        // Increase indentation to help with readability
2268                        self.write_stmt(module, sta, func_ctx, l2)?;
2269                    }
2270                }
2271
2272                writeln!(self.out, "{level}}}")?
2273            }
2274            // TODO: copy-paste from glsl-out
2275            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276            Statement::Return { value: None } => {
2277                writeln!(self.out, "{level}return;")?;
2278            }
2279            Statement::Return { value: Some(expr) } => {
2280                let base_ty_res = &func_ctx.info[expr].ty;
2281                let mut resolved = base_ty_res.inner_with(&module.types);
2282                if let TypeInner::Pointer { base, space: _ } = *resolved {
2283                    resolved = &module.types[base].inner;
2284                }
2285
2286                if let TypeInner::Struct { .. } = *resolved {
2287                    // We can safely unwrap here, since we now we working with struct
2288                    let ty = base_ty_res.handle().unwrap();
2289                    let struct_name = &self.names[&NameKey::Type(ty)];
2290                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2291                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292                    self.write_expr(module, expr, func_ctx)?;
2293                    writeln!(self.out, ";")?;
2294
2295                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2296                    let ep_output = match func_ctx.ty {
2297                        back::FunctionType::Function(_) => None,
2298                        back::FunctionType::EntryPoint(index) => self
2299                            .entry_point_io
2300                            .get(&(index as usize))
2301                            .unwrap()
2302                            .output
2303                            .as_ref(),
2304                    };
2305                    let final_name = match ep_output {
2306                        Some(ep_output) => {
2307                            let final_name = self.namer.call(&variable_name);
2308                            write!(
2309                                self.out,
2310                                "{}const {} {} = {{ ",
2311                                level, ep_output.ty_name, final_name,
2312                            )?;
2313                            for (index, m) in ep_output.members.iter().enumerate() {
2314                                if index != 0 {
2315                                    write!(self.out, ", ")?;
2316                                }
2317                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318                                write!(self.out, "{variable_name}.{member_name}")?;
2319                            }
2320                            writeln!(self.out, " }};")?;
2321                            final_name
2322                        }
2323                        None => variable_name,
2324                    };
2325                    writeln!(self.out, "{level}return {final_name};")?;
2326                } else {
2327                    write!(self.out, "{level}return ")?;
2328                    self.write_expr(module, expr, func_ctx)?;
2329                    writeln!(self.out, ";")?
2330                }
2331            }
2332            Statement::Store { pointer, value } => {
2333                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334                if ty_inner.is_atomic_pointer(&module.types) {
2335                    let pointer_space = ty_inner.pointer_space().unwrap();
2336                    let dummy = self.namer.call("dummy");
2337                    write!(self.out, "{level}{{ ")?;
2338                    if let TypeInner::Pointer { base, .. } = *ty_inner {
2339                        self.write_value_type(module, &module.types[base].inner)?;
2340                    }
2341                    write!(self.out, " {dummy} = 0; ")?;
2342                    match pointer_space {
2343                        crate::AddressSpace::WorkGroup => {
2344                            write!(self.out, "InterlockedExchange(")?;
2345                            self.write_expr(module, pointer, func_ctx)?;
2346                        }
2347                        crate::AddressSpace::Storage { .. } => {
2348                            let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349                            let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350                            write!(self.out, "{var_name}.InterlockedExchange(")?;
2351                            let chain = mem::take(&mut self.temp_access_chain);
2352                            self.write_storage_address(module, &chain, func_ctx)?;
2353                            self.temp_access_chain = chain;
2354                        }
2355                        _ => unreachable!(),
2356                    }
2357                    write!(self.out, ", ")?;
2358                    self.write_expr(module, value, func_ctx)?;
2359                    writeln!(self.out, ", {dummy}); }}")?;
2360                } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362                    self.write_storage_store(
2363                        module,
2364                        var_handle,
2365                        StoreValue::Expression(value),
2366                        func_ctx,
2367                        level,
2368                        None,
2369                    )?;
2370                } else {
2371                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2372                    // See the module-level block comment in mod.rs for details.
2373                    //
2374                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2375                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2376                    enum MatrixAccess {
2377                        Direct {
2378                            base: Handle<crate::Expression>,
2379                            index: u32,
2380                        },
2381                        Struct {
2382                            columns: crate::VectorSize,
2383                            width: u8,
2384                            base: Handle<crate::Expression>,
2385                        },
2386                    }
2387
2388                    let get_members = |expr: Handle<crate::Expression>| {
2389                        let resolved = func_ctx.resolve_type(expr, &module.types);
2390                        match *resolved {
2391                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2392                                TypeInner::Struct { ref members, .. } => Some(members),
2393                                _ => None,
2394                            },
2395                            _ => None,
2396                        }
2397                    };
2398
2399                    write!(self.out, "{level}")?;
2400
2401                    let matrix_access_on_lhs =
2402                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2403                            |(matrix_expr, vector, scalar)| match (
2404                                func_ctx.resolve_type(matrix_expr, &module.types),
2405                                &func_ctx.expressions[matrix_expr],
2406                            ) {
2407                                (
2408                                    &TypeInner::Pointer { base: ty, .. },
2409                                    &crate::Expression::AccessIndex { base, index },
2410                                ) if matches!(
2411                                    module.types[ty].inner,
2412                                    TypeInner::Matrix {
2413                                        rows: crate::VectorSize::Bi,
2414                                        ..
2415                                    }
2416                                ) && get_members(base)
2417                                    .map(|members| members[index as usize].binding.is_none())
2418                                    == Some(true) =>
2419                                {
2420                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2421                                }
2422                                _ => {
2423                                    if let Some(MatrixType {
2424                                        columns,
2425                                        rows: crate::VectorSize::Bi,
2426                                        width,
2427                                    }) = get_inner_matrix_of_struct_array_member(
2428                                        module,
2429                                        matrix_expr,
2430                                        func_ctx,
2431                                        true,
2432                                    ) {
2433                                        Some((
2434                                            MatrixAccess::Struct {
2435                                                columns,
2436                                                width,
2437                                                base: matrix_expr,
2438                                            },
2439                                            vector,
2440                                            scalar,
2441                                        ))
2442                                    } else {
2443                                        None
2444                                    }
2445                                }
2446                            },
2447                        );
2448
2449                    match matrix_access_on_lhs {
2450                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2451                            let base_ty_res = &func_ctx.info[base].ty;
2452                            let resolved = base_ty_res.inner_with(&module.types);
2453                            let ty = match *resolved {
2454                                TypeInner::Pointer { base, .. } => base,
2455                                _ => base_ty_res.handle().unwrap(),
2456                            };
2457
2458                            if let Some(Index::Static(vec_index)) = vector {
2459                                self.write_expr(module, base, func_ctx)?;
2460                                write!(
2461                                    self.out,
2462                                    ".{}_{}",
2463                                    &self.names[&NameKey::StructMember(ty, index)],
2464                                    vec_index
2465                                )?;
2466
2467                                if let Some(scalar_index) = scalar {
2468                                    write!(self.out, "[")?;
2469                                    self.write_index(module, scalar_index, func_ctx)?;
2470                                    write!(self.out, "]")?;
2471                                }
2472
2473                                write!(self.out, " = ")?;
2474                                self.write_expr(module, value, func_ctx)?;
2475                                writeln!(self.out, ";")?;
2476                            } else {
2477                                let access = WrappedStructMatrixAccess { ty, index };
2478                                match (&vector, &scalar) {
2479                                    (&Some(_), &Some(_)) => {
2480                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2481                                            access,
2482                                        )?;
2483                                    }
2484                                    (&Some(_), &None) => {
2485                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2486                                            access,
2487                                        )?;
2488                                    }
2489                                    (&None, _) => {
2490                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2491                                    }
2492                                }
2493
2494                                write!(self.out, "(")?;
2495                                self.write_expr(module, base, func_ctx)?;
2496                                write!(self.out, ", ")?;
2497                                self.write_expr(module, value, func_ctx)?;
2498
2499                                if let Some(Index::Expression(vec_index)) = vector {
2500                                    write!(self.out, ", ")?;
2501                                    self.write_expr(module, vec_index, func_ctx)?;
2502
2503                                    if let Some(scalar_index) = scalar {
2504                                        write!(self.out, ", ")?;
2505                                        self.write_index(module, scalar_index, func_ctx)?;
2506                                    }
2507                                }
2508                                writeln!(self.out, ");")?;
2509                            }
2510                        }
2511                        Some((
2512                            MatrixAccess::Struct {
2513                                columns,
2514                                width,
2515                                base,
2516                            },
2517                            Some(Index::Expression(vec_index)),
2518                            scalar,
2519                        )) => {
2520                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2521                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2522
2523                            if scalar.is_some() {
2524                                write!(
2525                                    self.out,
2526                                    "__set_el_of_mat{}x2_f{}",
2527                                    columns as u8,
2528                                    width * 8
2529                                )?;
2530                            } else {
2531                                write!(
2532                                    self.out,
2533                                    "__set_col_of_mat{}x2_f{}",
2534                                    columns as u8,
2535                                    width * 8
2536                                )?;
2537                            }
2538                            write!(self.out, "(")?;
2539                            self.write_expr(module, base, func_ctx)?;
2540                            write!(self.out, ", ")?;
2541                            self.write_expr(module, vec_index, func_ctx)?;
2542
2543                            if let Some(scalar_index) = scalar {
2544                                write!(self.out, ", ")?;
2545                                self.write_index(module, scalar_index, func_ctx)?;
2546                            }
2547
2548                            write!(self.out, ", ")?;
2549                            self.write_expr(module, value, func_ctx)?;
2550
2551                            writeln!(self.out, ");")?;
2552                        }
2553                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2554                        | Some((MatrixAccess::Struct { .. }, None, _))
2555                        | None => {
2556                            self.write_expr(module, pointer, func_ctx)?;
2557                            write!(self.out, " = ")?;
2558
2559                            // We cast the RHS of this store in cases where the LHS
2560                            // is a struct member with type:
2561                            //  - matCx2 or
2562                            //  - a (possibly nested) array of matCx2's
2563                            if let Some(MatrixType {
2564                                columns,
2565                                rows: crate::VectorSize::Bi,
2566                                width,
2567                            }) = get_inner_matrix_of_struct_array_member(
2568                                module, pointer, func_ctx, false,
2569                            ) {
2570                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2571                                if let TypeInner::Pointer { base, .. } = *resolved {
2572                                    resolved = &module.types[base].inner;
2573                                }
2574
2575                                write!(self.out, "(__mat{}x2_f{}", columns as u8, width * 8)?;
2576                                if let TypeInner::Array { base, size, .. } = *resolved {
2577                                    self.write_array_size(module, base, size)?;
2578                                }
2579                                write!(self.out, ")")?;
2580                            }
2581
2582                            self.write_expr(module, value, func_ctx)?;
2583                            writeln!(self.out, ";")?
2584                        }
2585                    }
2586                }
2587            }
2588            Statement::Loop {
2589                ref body,
2590                ref continuing,
2591                break_if,
2592            } => {
2593                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2594                let gate_name = (!continuing.is_empty() || break_if.is_some())
2595                    .then(|| self.namer.call("loop_init"));
2596
2597                if let Some((ref decl, _)) = force_loop_bound_statements {
2598                    writeln!(self.out, "{decl}")?;
2599                }
2600                if let Some(ref gate_name) = gate_name {
2601                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2602                }
2603
2604                self.continue_ctx.enter_loop();
2605                writeln!(self.out, "{level}while(true) {{")?;
2606                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2607                    writeln!(self.out, "{break_and_inc}")?;
2608                }
2609                let l2 = level.next();
2610                if let Some(gate_name) = gate_name {
2611                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2612                    let l3 = l2.next();
2613                    for sta in continuing.iter() {
2614                        self.write_stmt(module, sta, func_ctx, l3)?;
2615                    }
2616                    if let Some(condition) = break_if {
2617                        write!(self.out, "{l3}if (")?;
2618                        self.write_expr(module, condition, func_ctx)?;
2619                        writeln!(self.out, ") {{")?;
2620                        writeln!(self.out, "{}break;", l3.next())?;
2621                        writeln!(self.out, "{l3}}}")?;
2622                    }
2623                    writeln!(self.out, "{l2}}}")?;
2624                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2625                }
2626
2627                for sta in body.iter() {
2628                    self.write_stmt(module, sta, func_ctx, l2)?;
2629                }
2630
2631                writeln!(self.out, "{level}}}")?;
2632                self.continue_ctx.exit_loop();
2633            }
2634            Statement::Break => writeln!(self.out, "{level}break;")?,
2635            Statement::Continue => {
2636                if let Some(variable) = self.continue_ctx.continue_encountered() {
2637                    writeln!(self.out, "{level}{variable} = true;")?;
2638                    writeln!(self.out, "{level}break;")?
2639                } else {
2640                    writeln!(self.out, "{level}continue;")?
2641                }
2642            }
2643            Statement::ControlBarrier(barrier) => {
2644                self.write_control_barrier(barrier, level)?;
2645            }
2646            Statement::MemoryBarrier(barrier) => {
2647                self.write_memory_barrier(barrier, level)?;
2648            }
2649            Statement::ImageStore {
2650                image,
2651                coordinate,
2652                array_index,
2653                value,
2654            } => {
2655                write!(self.out, "{level}")?;
2656                self.write_expr(module, image, func_ctx)?;
2657
2658                write!(self.out, "[")?;
2659                if let Some(index) = array_index {
2660                    // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
2661                    write!(self.out, "int3(")?;
2662                    self.write_expr(module, coordinate, func_ctx)?;
2663                    write!(self.out, ", ")?;
2664                    self.write_expr(module, index, func_ctx)?;
2665                    write!(self.out, ")")?;
2666                } else {
2667                    self.write_expr(module, coordinate, func_ctx)?;
2668                }
2669                write!(self.out, "]")?;
2670
2671                write!(self.out, " = ")?;
2672                self.write_expr(module, value, func_ctx)?;
2673                writeln!(self.out, ";")?;
2674            }
2675            Statement::Call {
2676                function,
2677                ref arguments,
2678                result,
2679            } => {
2680                write!(self.out, "{level}")?;
2681
2682                if let Some(expr) = result {
2683                    write!(self.out, "const ")?;
2684                    let name = Baked(expr).to_string();
2685                    let expr_ty = &func_ctx.info[expr].ty;
2686                    let ty_inner = match *expr_ty {
2687                        proc::TypeResolution::Handle(handle) => {
2688                            self.write_type(module, handle)?;
2689                            &module.types[handle].inner
2690                        }
2691                        proc::TypeResolution::Value(ref value) => {
2692                            self.write_value_type(module, value)?;
2693                            value
2694                        }
2695                    };
2696                    write!(self.out, " {name}")?;
2697                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2698                        self.write_array_size(module, base, size)?;
2699                    }
2700                    write!(self.out, " = ")?;
2701                    self.named_expressions.insert(expr, name);
2702                }
2703                let func_name = &self.names[&NameKey::Function(function)];
2704                write!(self.out, "{func_name}(")?;
2705                let mut any_args_written = false;
2706                let mut separator = || {
2707                    if any_args_written {
2708                        ", "
2709                    } else {
2710                        any_args_written = true;
2711                        ""
2712                    }
2713                };
2714                for argument in arguments {
2715                    write!(self.out, "{}", separator())?;
2716                    self.write_expr(module, *argument, func_ctx)?;
2717                }
2718                if let Some(&var) = self.function_task_payload_var.get(&function) {
2719                    let name = &self.names[&NameKey::GlobalVariable(var)];
2720                    // Pass it through directly, whether its an in variable to this function or the global variable
2721                    write!(self.out, "{}{name}", separator())?;
2722                }
2723                writeln!(self.out, ");")?;
2724            }
2725            Statement::Atomic {
2726                pointer,
2727                ref fun,
2728                value,
2729                result,
2730            } => {
2731                write!(self.out, "{level}")?;
2732                let res_var_info = if let Some(res_handle) = result {
2733                    let name = Baked(res_handle).to_string();
2734                    match func_ctx.info[res_handle].ty {
2735                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2736                        proc::TypeResolution::Value(ref value) => {
2737                            self.write_value_type(module, value)?
2738                        }
2739                    };
2740                    write!(self.out, " {name}; ")?;
2741                    self.named_expressions.insert(res_handle, name.clone());
2742                    Some((res_handle, name))
2743                } else {
2744                    None
2745                };
2746                let pointer_space = func_ctx
2747                    .resolve_type(pointer, &module.types)
2748                    .pointer_space()
2749                    .unwrap();
2750                let fun_str = fun.to_hlsl_suffix();
2751                let compare_expr = match *fun {
2752                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2753                    _ => None,
2754                };
2755                match pointer_space {
2756                    crate::AddressSpace::WorkGroup => {
2757                        write!(self.out, "Interlocked{fun_str}(")?;
2758                        self.write_expr(module, pointer, func_ctx)?;
2759                        self.emit_hlsl_atomic_tail(
2760                            module,
2761                            func_ctx,
2762                            fun,
2763                            compare_expr,
2764                            value,
2765                            &res_var_info,
2766                        )?;
2767                    }
2768                    crate::AddressSpace::Storage { .. } => {
2769                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2770                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2771                        let width = match func_ctx.resolve_type(value, &module.types) {
2772                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2773                            _ => "",
2774                        };
2775                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2776                        let chain = mem::take(&mut self.temp_access_chain);
2777                        self.write_storage_address(module, &chain, func_ctx)?;
2778                        self.temp_access_chain = chain;
2779                        self.emit_hlsl_atomic_tail(
2780                            module,
2781                            func_ctx,
2782                            fun,
2783                            compare_expr,
2784                            value,
2785                            &res_var_info,
2786                        )?;
2787                    }
2788                    ref other => {
2789                        return Err(Error::Custom(format!(
2790                            "invalid address space {other:?} for atomic statement"
2791                        )))
2792                    }
2793                }
2794                if let Some(cmp) = compare_expr {
2795                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2796                        write!(
2797                            self.out,
2798                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2799                        )?;
2800                        self.write_expr(module, cmp, func_ctx)?;
2801                        writeln!(self.out, ");")?;
2802                    }
2803                }
2804            }
2805            Statement::ImageAtomic {
2806                image,
2807                coordinate,
2808                array_index,
2809                fun,
2810                value,
2811            } => {
2812                write!(self.out, "{level}")?;
2813
2814                let fun_str = fun.to_hlsl_suffix();
2815                write!(self.out, "Interlocked{fun_str}(")?;
2816                self.write_expr(module, image, func_ctx)?;
2817                write!(self.out, "[")?;
2818                self.write_texture_coordinates(
2819                    "int",
2820                    coordinate,
2821                    array_index,
2822                    None,
2823                    module,
2824                    func_ctx,
2825                )?;
2826                write!(self.out, "],")?;
2827
2828                self.write_expr(module, value, func_ctx)?;
2829                writeln!(self.out, ");")?;
2830            }
2831            Statement::WorkGroupUniformLoad { pointer, result } => {
2832                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2833                write!(self.out, "{level}")?;
2834                let name = Baked(result).to_string();
2835                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2836
2837                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2838            }
2839            Statement::Switch {
2840                selector,
2841                ref cases,
2842            } => {
2843                self.write_switch(module, func_ctx, level, selector, cases)?;
2844            }
2845            Statement::RayQuery { query, ref fun } => {
2846                // There are three possibilities for a ptr to be:
2847                // 1. A variable
2848                // 2. A function argument
2849                // 3. part of a struct
2850                //
2851                // 2 and 3 are not possible, a ray query (in naga IR)
2852                // is not allowed to be passed into a function, and
2853                // all languages disallow it in a struct (you get fun results if
2854                // you try it :) ).
2855                //
2856                // Therefore, the ray query expression must be a variable.
2857                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2858                else {
2859                    unreachable!()
2860                };
2861
2862                let tracker_expr_name = format!(
2863                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2864                    self.names[&func_ctx.name_key(query_var)]
2865                );
2866
2867                match *fun {
2868                    RayQueryFunction::Initialize {
2869                        acceleration_structure,
2870                        descriptor,
2871                    } => {
2872                        self.write_initialize_function(
2873                            module,
2874                            level,
2875                            query,
2876                            acceleration_structure,
2877                            descriptor,
2878                            &tracker_expr_name,
2879                            func_ctx,
2880                        )?;
2881                    }
2882                    RayQueryFunction::Proceed { result } => {
2883                        self.write_proceed(
2884                            module,
2885                            level,
2886                            query,
2887                            result,
2888                            &tracker_expr_name,
2889                            func_ctx,
2890                        )?;
2891                    }
2892                    RayQueryFunction::GenerateIntersection { hit_t } => {
2893                        self.write_generate_intersection(
2894                            module,
2895                            level,
2896                            query,
2897                            hit_t,
2898                            &tracker_expr_name,
2899                            func_ctx,
2900                        )?;
2901                    }
2902                    RayQueryFunction::ConfirmIntersection => {
2903                        self.write_confirm_intersection(
2904                            module,
2905                            level,
2906                            query,
2907                            &tracker_expr_name,
2908                            func_ctx,
2909                        )?;
2910                    }
2911                    RayQueryFunction::Terminate => {
2912                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2913                    }
2914                }
2915            }
2916            Statement::SubgroupBallot { result, predicate } => {
2917                write!(self.out, "{level}")?;
2918                let name = Baked(result).to_string();
2919                write!(self.out, "const uint4 {name} = ")?;
2920                self.named_expressions.insert(result, name);
2921
2922                write!(self.out, "WaveActiveBallot(")?;
2923                match predicate {
2924                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2925                    None => write!(self.out, "true")?,
2926                }
2927                writeln!(self.out, ");")?;
2928            }
2929            Statement::SubgroupCollectiveOperation {
2930                op,
2931                collective_op,
2932                argument,
2933                result,
2934            } => {
2935                write!(self.out, "{level}")?;
2936                write!(self.out, "const ")?;
2937                let name = Baked(result).to_string();
2938                match func_ctx.info[result].ty {
2939                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2940                    proc::TypeResolution::Value(ref value) => {
2941                        self.write_value_type(module, value)?
2942                    }
2943                };
2944                write!(self.out, " {name} = ")?;
2945                self.named_expressions.insert(result, name);
2946
2947                match (collective_op, op) {
2948                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2949                        write!(self.out, "WaveActiveAllTrue(")?
2950                    }
2951                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2952                        write!(self.out, "WaveActiveAnyTrue(")?
2953                    }
2954                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2955                        write!(self.out, "WaveActiveSum(")?
2956                    }
2957                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2958                        write!(self.out, "WaveActiveProduct(")?
2959                    }
2960                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2961                        write!(self.out, "WaveActiveMax(")?
2962                    }
2963                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2964                        write!(self.out, "WaveActiveMin(")?
2965                    }
2966                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2967                        write!(self.out, "WaveActiveBitAnd(")?
2968                    }
2969                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2970                        write!(self.out, "WaveActiveBitOr(")?
2971                    }
2972                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2973                        write!(self.out, "WaveActiveBitXor(")?
2974                    }
2975                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2976                        write!(self.out, "WavePrefixSum(")?
2977                    }
2978                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2979                        write!(self.out, "WavePrefixProduct(")?
2980                    }
2981                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2982                        self.write_expr(module, argument, func_ctx)?;
2983                        write!(self.out, " + WavePrefixSum(")?;
2984                    }
2985                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2986                        self.write_expr(module, argument, func_ctx)?;
2987                        write!(self.out, " * WavePrefixProduct(")?;
2988                    }
2989                    _ => unimplemented!(),
2990                }
2991                self.write_expr(module, argument, func_ctx)?;
2992                writeln!(self.out, ");")?;
2993            }
2994            Statement::SubgroupGather {
2995                mode,
2996                argument,
2997                result,
2998            } => {
2999                write!(self.out, "{level}")?;
3000                write!(self.out, "const ")?;
3001                let name = Baked(result).to_string();
3002                match func_ctx.info[result].ty {
3003                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
3004                    proc::TypeResolution::Value(ref value) => {
3005                        self.write_value_type(module, value)?
3006                    }
3007                };
3008                write!(self.out, " {name} = ")?;
3009                self.named_expressions.insert(result, name);
3010                match mode {
3011                    crate::GatherMode::BroadcastFirst => {
3012                        write!(self.out, "WaveReadLaneFirst(")?;
3013                        self.write_expr(module, argument, func_ctx)?;
3014                    }
3015                    crate::GatherMode::QuadBroadcast(index) => {
3016                        write!(self.out, "QuadReadLaneAt(")?;
3017                        self.write_expr(module, argument, func_ctx)?;
3018                        write!(self.out, ", ")?;
3019                        self.write_expr(module, index, func_ctx)?;
3020                    }
3021                    crate::GatherMode::QuadSwap(direction) => {
3022                        match direction {
3023                            crate::Direction::X => {
3024                                write!(self.out, "QuadReadAcrossX(")?;
3025                            }
3026                            crate::Direction::Y => {
3027                                write!(self.out, "QuadReadAcrossY(")?;
3028                            }
3029                            crate::Direction::Diagonal => {
3030                                write!(self.out, "QuadReadAcrossDiagonal(")?;
3031                            }
3032                        }
3033                        self.write_expr(module, argument, func_ctx)?;
3034                    }
3035                    _ => {
3036                        write!(self.out, "WaveReadLaneAt(")?;
3037                        self.write_expr(module, argument, func_ctx)?;
3038                        write!(self.out, ", ")?;
3039                        match mode {
3040                            crate::GatherMode::BroadcastFirst => unreachable!(),
3041                            crate::GatherMode::Broadcast(index)
3042                            | crate::GatherMode::Shuffle(index) => {
3043                                self.write_expr(module, index, func_ctx)?;
3044                            }
3045                            crate::GatherMode::ShuffleDown(index) => {
3046                                write!(self.out, "WaveGetLaneIndex() + ")?;
3047                                self.write_expr(module, index, func_ctx)?;
3048                            }
3049                            crate::GatherMode::ShuffleUp(index) => {
3050                                write!(self.out, "WaveGetLaneIndex() - ")?;
3051                                self.write_expr(module, index, func_ctx)?;
3052                            }
3053                            crate::GatherMode::ShuffleXor(index) => {
3054                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
3055                                self.write_expr(module, index, func_ctx)?;
3056                            }
3057                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3058                            crate::GatherMode::QuadSwap(_) => unreachable!(),
3059                        }
3060                    }
3061                }
3062                writeln!(self.out, ");")?;
3063            }
3064            Statement::CooperativeStore { .. } => unimplemented!(),
3065            Statement::RayPipelineFunction(_) => unreachable!(),
3066        }
3067
3068        Ok(())
3069    }
3070
3071    #[allow(clippy::too_many_arguments)]
3072    fn write_math_expression(
3073        &mut self,
3074        module: &Module,
3075        fun: crate::MathFunction,
3076        arg: Handle<crate::Expression>,
3077        arg1: Option<Handle<crate::Expression>>,
3078        arg2: Option<Handle<crate::Expression>>,
3079        arg3: Option<Handle<crate::Expression>>,
3080        func_ctx: &back::FunctionCtx<'_>,
3081    ) -> BackendResult {
3082        use crate::MathFunction as Mf;
3083
3084        enum Function {
3085            Asincosh { is_sin: bool },
3086            Atanh,
3087            Pack2x16float,
3088            Pack2x16snorm,
3089            Pack2x16unorm,
3090            Pack4x8snorm,
3091            Pack4x8unorm,
3092            Pack4xI8,
3093            Pack4xU8,
3094            Pack4xI8Clamp,
3095            Pack4xU8Clamp,
3096            Unpack2x16float,
3097            Unpack2x16snorm,
3098            Unpack2x16unorm,
3099            Unpack4x8snorm,
3100            Unpack4x8unorm,
3101            Unpack4xI8,
3102            Unpack4xU8,
3103            Dot4I8Packed,
3104            Dot4U8Packed,
3105            QuantizeToF16,
3106            Regular(&'static str),
3107            MissingIntOverload(&'static str),
3108            MissingIntReturnType(&'static str),
3109            CountTrailingZeros,
3110            CountLeadingZeros,
3111        }
3112
3113        let fun = match fun {
3114            // comparison
3115            Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
3116                Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
3117                _ => Function::Regular("abs"),
3118            },
3119            Mf::Min => Function::Regular("min"),
3120            Mf::Max => Function::Regular("max"),
3121            Mf::Clamp => Function::Regular("clamp"),
3122            Mf::Saturate => Function::Regular("saturate"),
3123            // trigonometry
3124            Mf::Cos => Function::Regular("cos"),
3125            Mf::Cosh => Function::Regular("cosh"),
3126            Mf::Sin => Function::Regular("sin"),
3127            Mf::Sinh => Function::Regular("sinh"),
3128            Mf::Tan => Function::Regular("tan"),
3129            Mf::Tanh => Function::Regular("tanh"),
3130            Mf::Acos => Function::Regular("acos"),
3131            Mf::Asin => Function::Regular("asin"),
3132            Mf::Atan => Function::Regular("atan"),
3133            Mf::Atan2 => Function::Regular("atan2"),
3134            Mf::Asinh => Function::Asincosh { is_sin: true },
3135            Mf::Acosh => Function::Asincosh { is_sin: false },
3136            Mf::Atanh => Function::Atanh,
3137            Mf::Radians => Function::Regular("radians"),
3138            Mf::Degrees => Function::Regular("degrees"),
3139            // decomposition
3140            Mf::Ceil => Function::Regular("ceil"),
3141            Mf::Floor => Function::Regular("floor"),
3142            Mf::Round => Function::Regular("round"),
3143            Mf::Fract => Function::Regular("frac"),
3144            Mf::Trunc => Function::Regular("trunc"),
3145            Mf::Modf => Function::Regular(MODF_FUNCTION),
3146            Mf::Frexp => Function::Regular(FREXP_FUNCTION),
3147            Mf::Ldexp => Function::Regular("ldexp"),
3148            // exponent
3149            Mf::Exp => Function::Regular("exp"),
3150            Mf::Exp2 => Function::Regular("exp2"),
3151            Mf::Log => Function::Regular("log"),
3152            Mf::Log2 => Function::Regular("log2"),
3153            Mf::Pow => Function::Regular("pow"),
3154            // geometry
3155            Mf::Dot => Function::Regular("dot"),
3156            Mf::Dot4I8Packed => Function::Dot4I8Packed,
3157            Mf::Dot4U8Packed => Function::Dot4U8Packed,
3158            //Mf::Outer => ,
3159            Mf::Cross => Function::Regular("cross"),
3160            Mf::Distance => Function::Regular("distance"),
3161            Mf::Length => Function::Regular("length"),
3162            Mf::Normalize => Function::Regular("normalize"),
3163            Mf::FaceForward => Function::Regular("faceforward"),
3164            Mf::Reflect => Function::Regular("reflect"),
3165            Mf::Refract => Function::Regular("refract"),
3166            // computational
3167            Mf::Sign => Function::Regular("sign"),
3168            Mf::Fma => Function::Regular("mad"),
3169            Mf::Mix => Function::Regular("lerp"),
3170            Mf::Step => Function::Regular("step"),
3171            Mf::SmoothStep => Function::Regular("smoothstep"),
3172            Mf::Sqrt => Function::Regular("sqrt"),
3173            Mf::InverseSqrt => Function::Regular("rsqrt"),
3174            //Mf::Inverse =>,
3175            Mf::Transpose => Function::Regular("transpose"),
3176            Mf::Determinant => Function::Regular("determinant"),
3177            Mf::QuantizeToF16 => Function::QuantizeToF16,
3178            // bits
3179            Mf::CountTrailingZeros => Function::CountTrailingZeros,
3180            Mf::CountLeadingZeros => Function::CountLeadingZeros,
3181            Mf::CountOneBits => Function::MissingIntOverload("countbits"),
3182            Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
3183            Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
3184            Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
3185            Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
3186            Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
3187            // Data Packing
3188            Mf::Pack2x16float => Function::Pack2x16float,
3189            Mf::Pack2x16snorm => Function::Pack2x16snorm,
3190            Mf::Pack2x16unorm => Function::Pack2x16unorm,
3191            Mf::Pack4x8snorm => Function::Pack4x8snorm,
3192            Mf::Pack4x8unorm => Function::Pack4x8unorm,
3193            Mf::Pack4xI8 => Function::Pack4xI8,
3194            Mf::Pack4xU8 => Function::Pack4xU8,
3195            Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
3196            Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
3197            // Data Unpacking
3198            Mf::Unpack2x16float => Function::Unpack2x16float,
3199            Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
3200            Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
3201            Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
3202            Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
3203            Mf::Unpack4xI8 => Function::Unpack4xI8,
3204            Mf::Unpack4xU8 => Function::Unpack4xU8,
3205            _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
3206        };
3207
3208        match fun {
3209            Function::Asincosh { is_sin } => {
3210                write!(self.out, "log(")?;
3211                self.write_expr(module, arg, func_ctx)?;
3212                write!(self.out, " + sqrt(")?;
3213                self.write_expr(module, arg, func_ctx)?;
3214                write!(self.out, " * ")?;
3215                self.write_expr(module, arg, func_ctx)?;
3216                match is_sin {
3217                    true => write!(self.out, " + 1.0))")?,
3218                    false => write!(self.out, " - 1.0))")?,
3219                }
3220            }
3221            Function::Atanh => {
3222                write!(self.out, "0.5 * log((1.0 + ")?;
3223                self.write_expr(module, arg, func_ctx)?;
3224                write!(self.out, ") / (1.0 - ")?;
3225                self.write_expr(module, arg, func_ctx)?;
3226                write!(self.out, "))")?;
3227            }
3228            Function::Pack2x16float => {
3229                write!(self.out, "(f32tof16(")?;
3230                self.write_expr(module, arg, func_ctx)?;
3231                write!(self.out, "[0]) | f32tof16(")?;
3232                self.write_expr(module, arg, func_ctx)?;
3233                write!(self.out, "[1]) << 16)")?;
3234            }
3235            Function::Pack2x16snorm => {
3236                let scale = 32767;
3237
3238                write!(self.out, "uint((int(round(clamp(")?;
3239                self.write_expr(module, arg, func_ctx)?;
3240                write!(
3241                    self.out,
3242                    "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
3243                )?;
3244                self.write_expr(module, arg, func_ctx)?;
3245                write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
3246            }
3247            Function::Pack2x16unorm => {
3248                let scale = 65535;
3249
3250                write!(self.out, "(uint(round(clamp(")?;
3251                self.write_expr(module, arg, func_ctx)?;
3252                write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3253                self.write_expr(module, arg, func_ctx)?;
3254                write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
3255            }
3256            Function::Pack4x8snorm => {
3257                let scale = 127;
3258
3259                write!(self.out, "uint((int(round(clamp(")?;
3260                self.write_expr(module, arg, func_ctx)?;
3261                write!(
3262                    self.out,
3263                    "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
3264                )?;
3265                self.write_expr(module, arg, func_ctx)?;
3266                write!(
3267                    self.out,
3268                    "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
3269                )?;
3270                self.write_expr(module, arg, func_ctx)?;
3271                write!(
3272                    self.out,
3273                    "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3274                )?;
3275                self.write_expr(module, arg, func_ctx)?;
3276                write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3277            }
3278            Function::Pack4x8unorm => {
3279                let scale = 255;
3280
3281                write!(self.out, "(uint(round(clamp(")?;
3282                self.write_expr(module, arg, func_ctx)?;
3283                write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3284                self.write_expr(module, arg, func_ctx)?;
3285                write!(
3286                    self.out,
3287                    "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3288                )?;
3289                self.write_expr(module, arg, func_ctx)?;
3290                write!(
3291                    self.out,
3292                    "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3293                )?;
3294                self.write_expr(module, arg, func_ctx)?;
3295                write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3296            }
3297            fun @ (Function::Pack4xI8
3298            | Function::Pack4xU8
3299            | Function::Pack4xI8Clamp
3300            | Function::Pack4xU8Clamp) => {
3301                let was_signed = matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
3302                let clamp_bounds = match fun {
3303                    Function::Pack4xI8Clamp => Some(("-128", "127")),
3304                    Function::Pack4xU8Clamp => Some(("0", "255")),
3305                    _ => None,
3306                };
3307                if was_signed {
3308                    write!(self.out, "uint(")?;
3309                }
3310                let write_arg = |this: &mut Self| -> BackendResult {
3311                    if let Some((min, max)) = clamp_bounds {
3312                        write!(this.out, "clamp(")?;
3313                        this.write_expr(module, arg, func_ctx)?;
3314                        write!(this.out, ", {min}, {max})")?;
3315                    } else {
3316                        this.write_expr(module, arg, func_ctx)?;
3317                    }
3318                    Ok(())
3319                };
3320                write!(self.out, "(")?;
3321                write_arg(self)?;
3322                write!(self.out, "[0] & 0xFF) | ((")?;
3323                write_arg(self)?;
3324                write!(self.out, "[1] & 0xFF) << 8) | ((")?;
3325                write_arg(self)?;
3326                write!(self.out, "[2] & 0xFF) << 16) | ((")?;
3327                write_arg(self)?;
3328                write!(self.out, "[3] & 0xFF) << 24)")?;
3329                if was_signed {
3330                    write!(self.out, ")")?;
3331                }
3332            }
3333
3334            Function::Unpack2x16float => {
3335                write!(self.out, "float2(f16tof32(")?;
3336                self.write_expr(module, arg, func_ctx)?;
3337                write!(self.out, "), f16tof32((")?;
3338                self.write_expr(module, arg, func_ctx)?;
3339                write!(self.out, ") >> 16))")?;
3340            }
3341            Function::Unpack2x16snorm => {
3342                let scale = 32767;
3343
3344                write!(self.out, "(float2(int2(")?;
3345                self.write_expr(module, arg, func_ctx)?;
3346                write!(self.out, " << 16, ")?;
3347                self.write_expr(module, arg, func_ctx)?;
3348                write!(self.out, ") >> 16) / {scale}.0)")?;
3349            }
3350            Function::Unpack2x16unorm => {
3351                let scale = 65535;
3352
3353                write!(self.out, "(float2(")?;
3354                self.write_expr(module, arg, func_ctx)?;
3355                write!(self.out, " & 0xFFFF, ")?;
3356                self.write_expr(module, arg, func_ctx)?;
3357                write!(self.out, " >> 16) / {scale}.0)")?;
3358            }
3359            Function::Unpack4x8snorm => {
3360                let scale = 127;
3361
3362                write!(self.out, "(float4(int4(")?;
3363                self.write_expr(module, arg, func_ctx)?;
3364                write!(self.out, " << 24, ")?;
3365                self.write_expr(module, arg, func_ctx)?;
3366                write!(self.out, " << 16, ")?;
3367                self.write_expr(module, arg, func_ctx)?;
3368                write!(self.out, " << 8, ")?;
3369                self.write_expr(module, arg, func_ctx)?;
3370                write!(self.out, ") >> 24) / {scale}.0)")?;
3371            }
3372            Function::Unpack4x8unorm => {
3373                let scale = 255;
3374
3375                write!(self.out, "(float4(")?;
3376                self.write_expr(module, arg, func_ctx)?;
3377                write!(self.out, " & 0xFF, ")?;
3378                self.write_expr(module, arg, func_ctx)?;
3379                write!(self.out, " >> 8 & 0xFF, ")?;
3380                self.write_expr(module, arg, func_ctx)?;
3381                write!(self.out, " >> 16 & 0xFF, ")?;
3382                self.write_expr(module, arg, func_ctx)?;
3383                write!(self.out, " >> 24) / {scale}.0)")?;
3384            }
3385            fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
3386                write!(self.out, "(")?;
3387                if matches!(fun, Function::Unpack4xU8) {
3388                    write!(self.out, "u")?;
3389                }
3390                write!(self.out, "int4(")?;
3391                self.write_expr(module, arg, func_ctx)?;
3392                write!(self.out, ", ")?;
3393                self.write_expr(module, arg, func_ctx)?;
3394                write!(self.out, " >> 8, ")?;
3395                self.write_expr(module, arg, func_ctx)?;
3396                write!(self.out, " >> 16, ")?;
3397                self.write_expr(module, arg, func_ctx)?;
3398                write!(self.out, " >> 24) << 24 >> 24)")?;
3399            }
3400            fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
3401                let arg1 = arg1.unwrap();
3402
3403                if self.options.shader_model >= ShaderModel::V6_4 {
3404                    // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
3405                    let function_name = match fun {
3406                        Function::Dot4I8Packed => "dot4add_i8packed",
3407                        Function::Dot4U8Packed => "dot4add_u8packed",
3408                        _ => unreachable!(),
3409                    };
3410                    write!(self.out, "{function_name}(")?;
3411                    self.write_expr(module, arg, func_ctx)?;
3412                    write!(self.out, ", ")?;
3413                    self.write_expr(module, arg1, func_ctx)?;
3414                    write!(self.out, ", 0)")?;
3415                } else {
3416                    // Fall back to a polyfill as `dot4add_u8packed` is not available.
3417                    write!(self.out, "dot(")?;
3418
3419                    if matches!(fun, Function::Dot4U8Packed) {
3420                        write!(self.out, "u")?;
3421                    }
3422                    write!(self.out, "int4(")?;
3423                    self.write_expr(module, arg, func_ctx)?;
3424                    write!(self.out, ", ")?;
3425                    self.write_expr(module, arg, func_ctx)?;
3426                    write!(self.out, " >> 8, ")?;
3427                    self.write_expr(module, arg, func_ctx)?;
3428                    write!(self.out, " >> 16, ")?;
3429                    self.write_expr(module, arg, func_ctx)?;
3430                    write!(self.out, " >> 24) << 24 >> 24, ")?;
3431
3432                    if matches!(fun, Function::Dot4U8Packed) {
3433                        write!(self.out, "u")?;
3434                    }
3435                    write!(self.out, "int4(")?;
3436                    self.write_expr(module, arg1, func_ctx)?;
3437                    write!(self.out, ", ")?;
3438                    self.write_expr(module, arg1, func_ctx)?;
3439                    write!(self.out, " >> 8, ")?;
3440                    self.write_expr(module, arg1, func_ctx)?;
3441                    write!(self.out, " >> 16, ")?;
3442                    self.write_expr(module, arg1, func_ctx)?;
3443                    write!(self.out, " >> 24) << 24 >> 24)")?;
3444                }
3445            }
3446            Function::QuantizeToF16 => {
3447                write!(self.out, "f16tof32(f32tof16(")?;
3448                self.write_expr(module, arg, func_ctx)?;
3449                write!(self.out, "))")?;
3450            }
3451            Function::Regular(fun_name) => {
3452                write!(self.out, "{fun_name}(")?;
3453                self.write_expr(module, arg, func_ctx)?;
3454                if let Some(arg) = arg1 {
3455                    write!(self.out, ", ")?;
3456                    self.write_expr(module, arg, func_ctx)?;
3457                }
3458                if let Some(arg) = arg2 {
3459                    write!(self.out, ", ")?;
3460                    self.write_expr(module, arg, func_ctx)?;
3461                }
3462                if let Some(arg) = arg3 {
3463                    write!(self.out, ", ")?;
3464                    self.write_expr(module, arg, func_ctx)?;
3465                }
3466                write!(self.out, ")")?
3467            }
3468            // These overloads are only missing on FXC, so this is only needed for 32bit types,
3469            // as non-32bit types are DXC only.
3470            Function::MissingIntOverload(fun_name) => {
3471                let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3472                if let Some(Scalar::I32) = scalar_kind {
3473                    write!(self.out, "asint({fun_name}(asuint(")?;
3474                    self.write_expr(module, arg, func_ctx)?;
3475                    write!(self.out, ")))")?;
3476                } else {
3477                    write!(self.out, "{fun_name}(")?;
3478                    self.write_expr(module, arg, func_ctx)?;
3479                    write!(self.out, ")")?;
3480                }
3481            }
3482            // These overloads are only missing on FXC, so this is only needed for 32bit types,
3483            // as non-32bit types are DXC only.
3484            Function::MissingIntReturnType(fun_name) => {
3485                let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3486                if let Some(Scalar::I32) = scalar_kind {
3487                    write!(self.out, "asint({fun_name}(")?;
3488                    self.write_expr(module, arg, func_ctx)?;
3489                    write!(self.out, "))")?;
3490                } else {
3491                    write!(self.out, "{fun_name}(")?;
3492                    self.write_expr(module, arg, func_ctx)?;
3493                    write!(self.out, ")")?;
3494                }
3495            }
3496            Function::CountTrailingZeros => {
3497                match *func_ctx.resolve_type(arg, &module.types) {
3498                    TypeInner::Vector { size, scalar } => {
3499                        let s = match size {
3500                            crate::VectorSize::Bi => ".xx",
3501                            crate::VectorSize::Tri => ".xxx",
3502                            crate::VectorSize::Quad => ".xxxx",
3503                        };
3504
3505                        let scalar_width_bits = scalar.width * 8;
3506
3507                        if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3508                            write!(self.out, "min(({scalar_width_bits}u){s}, firstbitlow(")?;
3509                            self.write_expr(module, arg, func_ctx)?;
3510                            write!(self.out, "))")?;
3511                        } else {
3512                            // This is only needed for the FXC path, on 32bit signed integers.
3513                            write!(
3514                                self.out,
3515                                "asint(min(({scalar_width_bits}u){s}, firstbitlow("
3516                            )?;
3517                            self.write_expr(module, arg, func_ctx)?;
3518                            write!(self.out, ")))")?;
3519                        }
3520                    }
3521                    TypeInner::Scalar(scalar) => {
3522                        let scalar_width_bits = scalar.width * 8;
3523
3524                        if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3525                            write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
3526                            self.write_expr(module, arg, func_ctx)?;
3527                            write!(self.out, "))")?;
3528                        } else {
3529                            // This is only needed for the FXC path, on 32bit signed integers.
3530                            write!(self.out, "asint(min({scalar_width_bits}u, firstbitlow(")?;
3531                            self.write_expr(module, arg, func_ctx)?;
3532                            write!(self.out, ")))")?;
3533                        }
3534                    }
3535                    _ => unreachable!(),
3536                }
3537
3538                return Ok(());
3539            }
3540            Function::CountLeadingZeros => {
3541                match *func_ctx.resolve_type(arg, &module.types) {
3542                    TypeInner::Vector { size, scalar } => {
3543                        let s = match size {
3544                            crate::VectorSize::Bi => ".xx",
3545                            crate::VectorSize::Tri => ".xxx",
3546                            crate::VectorSize::Quad => ".xxxx",
3547                        };
3548
3549                        // scalar width - 1
3550                        let constant = scalar.width * 8 - 1;
3551
3552                        if scalar.kind == ScalarKind::Uint {
3553                            write!(self.out, "(({constant}u){s} - firstbithigh(")?;
3554                            self.write_expr(module, arg, func_ctx)?;
3555                            write!(self.out, "))")?;
3556                        } else {
3557                            let conversion_func = match scalar.width {
3558                                4 => "asint",
3559                                _ => "",
3560                            };
3561                            write!(self.out, "(")?;
3562                            self.write_expr(module, arg, func_ctx)?;
3563                            write!(
3564                                self.out,
3565                                " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
3566                            )?;
3567                            self.write_expr(module, arg, func_ctx)?;
3568                            write!(self.out, ")))")?;
3569                        }
3570                    }
3571                    TypeInner::Scalar(scalar) => {
3572                        // scalar width - 1
3573                        let constant = scalar.width * 8 - 1;
3574
3575                        if let ScalarKind::Uint = scalar.kind {
3576                            write!(self.out, "({constant}u - firstbithigh(")?;
3577                            self.write_expr(module, arg, func_ctx)?;
3578                            write!(self.out, "))")?;
3579                        } else {
3580                            let conversion_func = match scalar.width {
3581                                4 => "asint",
3582                                _ => "",
3583                            };
3584                            write!(self.out, "(")?;
3585                            self.write_expr(module, arg, func_ctx)?;
3586                            write!(
3587                                self.out,
3588                                " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
3589                            )?;
3590                            self.write_expr(module, arg, func_ctx)?;
3591                            write!(self.out, ")))")?;
3592                        }
3593                    }
3594                    _ => unreachable!(),
3595                }
3596            }
3597        }
3598
3599        Ok(())
3600    }
3601
3602    fn write_const_expression(
3603        &mut self,
3604        module: &Module,
3605        expr: Handle<crate::Expression>,
3606        arena: &crate::Arena<crate::Expression>,
3607    ) -> BackendResult {
3608        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3609            writer.write_const_expression(module, expr, arena)
3610        })
3611    }
3612
3613    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3614        match literal {
3615            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3616            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3617            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3618            crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
3619            crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
3620            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3621            // `-2147483648` is parsed by some compilers as unary negation of
3622            // positive 2147483648, which is too large for an int, causing
3623            // issues for some compilers. Neither DXC nor FXC appear to have
3624            // this problem, but this is not specified and could change. We
3625            // therefore use `-2147483647 - 1` as a precaution.
3626            crate::Literal::I32(value) if value == i32::MIN => {
3627                write!(self.out, "int({} - 1)", value + 1)?
3628            }
3629            // HLSL has no suffix for explicit i32 literals, but not using any suffix
3630            // makes the type ambiguous which prevents overload resolution from
3631            // working. So we explicitly use the int() constructor syntax.
3632            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3633            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3634            // I64 version of the minimum I32 value issue described above.
3635            crate::Literal::I64(value) if value == i64::MIN => {
3636                write!(self.out, "({}L - 1L)", value + 1)?;
3637            }
3638            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3639            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3640            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3641                return Err(Error::Custom(
3642                    "Abstract types should not appear in IR presented to backends".into(),
3643                ));
3644            }
3645        }
3646        Ok(())
3647    }
3648
3649    fn write_possibly_const_expression<E>(
3650        &mut self,
3651        module: &Module,
3652        expr: Handle<crate::Expression>,
3653        expressions: &crate::Arena<crate::Expression>,
3654        write_expression: E,
3655    ) -> BackendResult
3656    where
3657        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3658    {
3659        use crate::Expression;
3660
3661        match expressions[expr] {
3662            Expression::Literal(literal) => {
3663                self.write_literal(literal)?;
3664            }
3665            Expression::Constant(handle) => {
3666                let constant = &module.constants[handle];
3667                if constant.name.is_some() {
3668                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3669                } else {
3670                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
3671                }
3672            }
3673            Expression::ZeroValue(ty) => {
3674                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3675                write!(self.out, "()")?;
3676            }
3677            Expression::Compose { ty, ref components } => {
3678                match module.types[ty].inner {
3679                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3680                        self.write_wrapped_constructor_function_name(
3681                            module,
3682                            WrappedConstructor { ty },
3683                        )?;
3684                    }
3685                    _ => {
3686                        self.write_type(module, ty)?;
3687                    }
3688                };
3689                write!(self.out, "(")?;
3690                for (index, component) in components.iter().enumerate() {
3691                    if index != 0 {
3692                        write!(self.out, ", ")?;
3693                    }
3694                    write_expression(self, *component)?;
3695                }
3696                write!(self.out, ")")?;
3697            }
3698            Expression::Splat { size, value } => {
3699                // hlsl is not supported one value constructor
3700                // if we write, for example, int4(0), dxc returns error:
3701                // error: too few elements in vector initialization (expected 4 elements, have 1)
3702                let number_of_components = match size {
3703                    crate::VectorSize::Bi => "xx",
3704                    crate::VectorSize::Tri => "xxx",
3705                    crate::VectorSize::Quad => "xxxx",
3706                };
3707                write!(self.out, "(")?;
3708                write_expression(self, value)?;
3709                write!(self.out, ").{number_of_components}")?
3710            }
3711            _ => {
3712                return Err(Error::Override);
3713            }
3714        }
3715
3716        Ok(())
3717    }
3718
3719    /// Helper method to write expressions
3720    ///
3721    /// # Notes
3722    /// Doesn't add any newlines or leading/trailing spaces
3723    pub(super) fn write_expr(
3724        &mut self,
3725        module: &Module,
3726        expr: Handle<crate::Expression>,
3727        func_ctx: &back::FunctionCtx<'_>,
3728    ) -> BackendResult {
3729        use crate::Expression;
3730
3731        // Handle the special semantics of vertex_index/instance_index
3732        let ff_input = if self.options.special_constants_binding.is_some() {
3733            func_ctx.is_fixed_function_input(expr, module)
3734        } else {
3735            None
3736        };
3737        let closing_bracket = match ff_input {
3738            Some(crate::BuiltIn::VertexIndex) => {
3739                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3740                ")"
3741            }
3742            Some(crate::BuiltIn::InstanceIndex) => {
3743                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3744                ")"
3745            }
3746            Some(crate::BuiltIn::NumWorkGroups) => {
3747                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
3748                // in compute shaders the special constants contain the number
3749                // of workgroups, which we are using here.
3750                write!(
3751                    self.out,
3752                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3753                )?;
3754                return Ok(());
3755            }
3756            _ => "",
3757        };
3758
3759        if let Some(name) = self.named_expressions.get(&expr) {
3760            write!(self.out, "{name}{closing_bracket}")?;
3761            return Ok(());
3762        }
3763
3764        let expression = &func_ctx.expressions[expr];
3765
3766        match *expression {
3767            Expression::Literal(_)
3768            | Expression::Constant(_)
3769            | Expression::ZeroValue(_)
3770            | Expression::Compose { .. }
3771            | Expression::Splat { .. } => {
3772                self.write_possibly_const_expression(
3773                    module,
3774                    expr,
3775                    func_ctx.expressions,
3776                    |writer, expr| writer.write_expr(module, expr, func_ctx),
3777                )?;
3778            }
3779            Expression::Override(_) => return Err(Error::Override),
3780            // Avoid undefined behaviour for addition, subtraction, and
3781            // multiplication of signed integers by casting operands to
3782            // unsigned, performing the operation, then casting the result back
3783            // to signed.
3784            // TODO(#7109): This relies on the asint()/asuint() functions which only work
3785            // for 32-bit types, so we must find another solution for different bit widths.
3786            Expression::Binary {
3787                op:
3788                    op @ crate::BinaryOperator::Add
3789                    | op @ crate::BinaryOperator::Subtract
3790                    | op @ crate::BinaryOperator::Multiply,
3791                left,
3792                right,
3793            } if matches!(
3794                func_ctx.resolve_type(expr, &module.types).scalar(),
3795                Some(Scalar::I32)
3796            ) =>
3797            {
3798                write!(self.out, "asint(asuint(",)?;
3799                self.write_expr(module, left, func_ctx)?;
3800                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3801                self.write_expr(module, right, func_ctx)?;
3802                write!(self.out, "))")?;
3803            }
3804            // All of the multiplication can be expressed as `mul`,
3805            // except vector * vector, which needs to use the "*" operator.
3806            Expression::Binary {
3807                op: crate::BinaryOperator::Multiply,
3808                left,
3809                right,
3810            } if func_ctx.resolve_type(left, &module.types).is_matrix()
3811                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3812            {
3813                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
3814                write!(self.out, "mul(")?;
3815                self.write_expr(module, right, func_ctx)?;
3816                write!(self.out, ", ")?;
3817                self.write_expr(module, left, func_ctx)?;
3818                write!(self.out, ")")?;
3819            }
3820
3821            // WGSL says that floating-point division by zero should return
3822            // infinity. Microsoft's Direct3D 11 functional specification
3823            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
3824            // says:
3825            //
3826            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3827            //
3828            // which is what we want. The DXIL specification for the FDiv
3829            // instruction corroborates this:
3830            //
3831            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3832            Expression::Binary {
3833                op: crate::BinaryOperator::Divide,
3834                left,
3835                right,
3836            } if matches!(
3837                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3838                Some(ScalarKind::Sint | ScalarKind::Uint)
3839            ) =>
3840            {
3841                write!(self.out, "{DIV_FUNCTION}(")?;
3842                self.write_expr(module, left, func_ctx)?;
3843                write!(self.out, ", ")?;
3844                self.write_expr(module, right, func_ctx)?;
3845                write!(self.out, ")")?;
3846            }
3847
3848            Expression::Binary {
3849                op: crate::BinaryOperator::Modulo,
3850                left,
3851                right,
3852            } if matches!(
3853                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3854                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3855            ) =>
3856            {
3857                write!(self.out, "{MOD_FUNCTION}(")?;
3858                self.write_expr(module, left, func_ctx)?;
3859                write!(self.out, ", ")?;
3860                self.write_expr(module, right, func_ctx)?;
3861                write!(self.out, ")")?;
3862            }
3863
3864            Expression::Binary { op, left, right } => {
3865                write!(self.out, "(")?;
3866                self.write_expr(module, left, func_ctx)?;
3867                write!(self.out, " {} ", back::binary_operation_str(op))?;
3868                self.write_expr(module, right, func_ctx)?;
3869                write!(self.out, ")")?;
3870            }
3871            Expression::Access { base, index } => {
3872                if let Some(crate::AddressSpace::Storage { .. }) =
3873                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3874                {
3875                    // do nothing, the chain is written on `Load`/`Store`
3876                } else {
3877                    // We use the function __get_col_of_matCx2 here in cases
3878                    // where `base`s type resolves to a matCx2 and is part of a
3879                    // struct member with type of (possibly nested) array of matCx2's.
3880                    //
3881                    // Note that this only works for `Load`s and we handle
3882                    // `Store`s differently in `Statement::Store`.
3883                    if let Some(MatrixType {
3884                        columns,
3885                        rows: crate::VectorSize::Bi,
3886                        width,
3887                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3888                        .or_else(|| {
3889                            get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
3890                        })
3891                    {
3892                        write!(
3893                            self.out,
3894                            "__get_col_of_mat{}x2_f{}(",
3895                            columns as u8,
3896                            width * 8
3897                        )?;
3898                        self.write_expr(module, base, func_ctx)?;
3899                        write!(self.out, ", ")?;
3900                        self.write_expr(module, index, func_ctx)?;
3901                        write!(self.out, ")")?;
3902                        return Ok(());
3903                    }
3904
3905                    let resolved = func_ctx.resolve_type(base, &module.types);
3906
3907                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3908                        TypeInner::BindingArray { .. } => {
3909                            let uniformity = &func_ctx.info[index].uniformity;
3910
3911                            (true, uniformity.non_uniform_result.is_some())
3912                        }
3913                        _ => (false, false),
3914                    };
3915
3916                    self.write_expr(module, base, func_ctx)?;
3917
3918                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3919                        module, func_ctx, base, resolved,
3920                    );
3921
3922                    if let Some(ref info) = array_sampler_info {
3923                        write!(self.out, "{}[", info.sampler_heap_name)?;
3924                    } else {
3925                        write!(self.out, "[")?;
3926                    }
3927
3928                    let needs_bound_check = self.options.restrict_indexing
3929                        && !indexing_binding_array
3930                        && match resolved.pointer_space() {
3931                            Some(
3932                                crate::AddressSpace::Function
3933                                | crate::AddressSpace::Private
3934                                | crate::AddressSpace::WorkGroup
3935                                | crate::AddressSpace::Immediate
3936                                | crate::AddressSpace::TaskPayload
3937                                | crate::AddressSpace::RayPayload
3938                                | crate::AddressSpace::IncomingRayPayload,
3939                            )
3940                            | None => true,
3941                            Some(crate::AddressSpace::Uniform) => {
3942                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3943                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3944                                let bind_target = self
3945                                    .options
3946                                    .resolve_resource_binding(
3947                                        module.global_variables[var_handle]
3948                                            .binding
3949                                            .as_ref()
3950                                            .unwrap(),
3951                                    )
3952                                    .unwrap();
3953                                bind_target.restrict_indexing
3954                            }
3955                            Some(
3956                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3957                            ) => unreachable!(),
3958                        };
3959                    // Decide whether this index needs to be clamped to fall within range.
3960                    let restriction_needed = if needs_bound_check {
3961                        index::access_needs_check(
3962                            base,
3963                            index::GuardedIndex::Expression(index),
3964                            module,
3965                            func_ctx.expressions,
3966                            func_ctx.info,
3967                        )
3968                    } else {
3969                        None
3970                    };
3971                    if let Some(limit) = restriction_needed {
3972                        write!(self.out, "min(uint(")?;
3973                        self.write_expr(module, index, func_ctx)?;
3974                        write!(self.out, "), ")?;
3975                        match limit {
3976                            index::IndexableLength::Known(limit) => {
3977                                write!(self.out, "{}u", limit - 1)?;
3978                            }
3979                            index::IndexableLength::Dynamic => unreachable!(),
3980                        }
3981                        write!(self.out, ")")?;
3982                    } else {
3983                        if non_uniform_qualifier {
3984                            write!(self.out, "NonUniformResourceIndex(")?;
3985                        }
3986                        if let Some(ref info) = array_sampler_info {
3987                            write!(
3988                                self.out,
3989                                "{}[{} + ",
3990                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3991                            )?;
3992                        }
3993                        self.write_expr(module, index, func_ctx)?;
3994                        if array_sampler_info.is_some() {
3995                            write!(self.out, "]")?;
3996                        }
3997                        if non_uniform_qualifier {
3998                            write!(self.out, ")")?;
3999                        }
4000                    }
4001
4002                    write!(self.out, "]")?;
4003                }
4004            }
4005            Expression::AccessIndex { base, index } => {
4006                if let Some(crate::AddressSpace::Storage { .. }) =
4007                    func_ctx.resolve_type(expr, &module.types).pointer_space()
4008                {
4009                    // do nothing, the chain is written on `Load`/`Store`
4010                } else {
4011                    // See if we need to write the matrix column access in a
4012                    // special way since the type of `base` is our special
4013                    // __matCx2 struct.
4014                    if let Some(MatrixType {
4015                        rows: crate::VectorSize::Bi,
4016                        ..
4017                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
4018                        .or_else(|| {
4019                            get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
4020                        })
4021                    {
4022                        self.write_expr(module, base, func_ctx)?;
4023                        write!(self.out, "._{index}")?;
4024                        return Ok(());
4025                    }
4026
4027                    let base_ty_res = &func_ctx.info[base].ty;
4028                    let mut resolved = base_ty_res.inner_with(&module.types);
4029                    let base_ty_handle = match *resolved {
4030                        TypeInner::Pointer { base, .. } => {
4031                            resolved = &module.types[base].inner;
4032                            Some(base)
4033                        }
4034                        _ => base_ty_res.handle(),
4035                    };
4036
4037                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
4038                    // See the module-level block comment in mod.rs for details.
4039                    //
4040                    // We handle matrix reconstruction here for Loads.
4041                    // Stores are handled directly by `Statement::Store`.
4042                    if let TypeInner::Struct { ref members, .. } = *resolved {
4043                        let member = &members[index as usize];
4044
4045                        match module.types[member.ty].inner {
4046                            TypeInner::Matrix {
4047                                rows: crate::VectorSize::Bi,
4048                                ..
4049                            } if member.binding.is_none() => {
4050                                let ty = base_ty_handle.unwrap();
4051                                self.write_wrapped_struct_matrix_get_function_name(
4052                                    WrappedStructMatrixAccess { ty, index },
4053                                )?;
4054                                write!(self.out, "(")?;
4055                                self.write_expr(module, base, func_ctx)?;
4056                                write!(self.out, ")")?;
4057                                return Ok(());
4058                            }
4059                            _ => {}
4060                        }
4061                    }
4062
4063                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
4064                        module, func_ctx, base, resolved,
4065                    );
4066
4067                    if let Some(ref info) = array_sampler_info {
4068                        write!(
4069                            self.out,
4070                            "{}[{}",
4071                            info.sampler_heap_name, info.sampler_index_buffer_name
4072                        )?;
4073                    }
4074
4075                    self.write_expr(module, base, func_ctx)?;
4076
4077                    match *resolved {
4078                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
4079                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
4080                        // it and generates completely nonsensical DXBC.
4081                        //
4082                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
4083                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
4084                            // Write vector access as a swizzle
4085                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
4086                        }
4087                        TypeInner::Matrix { .. }
4088                        | TypeInner::Array { .. }
4089                        | TypeInner::BindingArray { .. } => {
4090                            if let Some(ref info) = array_sampler_info {
4091                                write!(
4092                                    self.out,
4093                                    "[{} + {index}]",
4094                                    info.binding_array_base_index_name
4095                                )?;
4096                            } else {
4097                                write!(self.out, "[{index}]")?;
4098                            }
4099                        }
4100                        TypeInner::Struct { .. } => {
4101                            // This will never panic in case the type is a `Struct`, this is not true
4102                            // for other types so we can only check while inside this match arm
4103                            let ty = base_ty_handle.unwrap();
4104
4105                            write!(
4106                                self.out,
4107                                ".{}",
4108                                &self.names[&NameKey::StructMember(ty, index)]
4109                            )?
4110                        }
4111                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
4112                    }
4113
4114                    if array_sampler_info.is_some() {
4115                        write!(self.out, "]")?;
4116                    }
4117                }
4118            }
4119            Expression::FunctionArgument(pos) => {
4120                let ty = func_ctx.resolve_type(expr, &module.types);
4121
4122                // We know that any external texture function argument has been expanded into
4123                // separate consecutive arguments for each plane and the parameters buffer. And we
4124                // also know that external textures can only ever be used as an argument to another
4125                // function. Therefore we can simply emit each of the expanded arguments in a
4126                // consecutive comma-separated list.
4127                if let TypeInner::Image {
4128                    class: crate::ImageClass::External,
4129                    ..
4130                } = *ty
4131                {
4132                    let plane_names = [0, 1, 2].map(|i| {
4133                        &self.names[&func_ctx
4134                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
4135                    });
4136                    let params_name = &self.names[&func_ctx
4137                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
4138                    write!(
4139                        self.out,
4140                        "{}, {}, {}, {}",
4141                        plane_names[0], plane_names[1], plane_names[2], params_name
4142                    )?;
4143                } else {
4144                    let key = func_ctx.argument_key(pos);
4145                    let name = &self.names[&key];
4146                    write!(self.out, "{name}")?;
4147                }
4148            }
4149            Expression::ImageSample {
4150                coordinate,
4151                image,
4152                sampler,
4153                clamp_to_edge: true,
4154                gather: None,
4155                array_index: None,
4156                offset: None,
4157                level: crate::SampleLevel::Zero,
4158                depth_ref: None,
4159            } => {
4160                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
4161                self.write_expr(module, image, func_ctx)?;
4162                write!(self.out, ", ")?;
4163                self.write_expr(module, sampler, func_ctx)?;
4164                write!(self.out, ", ")?;
4165                self.write_expr(module, coordinate, func_ctx)?;
4166                write!(self.out, ")")?;
4167            }
4168            Expression::ImageSample {
4169                image,
4170                sampler,
4171                gather,
4172                coordinate,
4173                array_index,
4174                offset,
4175                level,
4176                depth_ref,
4177                clamp_to_edge,
4178            } => {
4179                if clamp_to_edge {
4180                    return Err(Error::Custom(
4181                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
4182                    ));
4183                }
4184
4185                use crate::SampleLevel as Sl;
4186                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
4187
4188                let (base_str, component_str) = match gather {
4189                    Some(component) => ("Gather", COMPONENTS[component as usize]),
4190                    None => ("Sample", ""),
4191                };
4192                let cmp_str = match depth_ref {
4193                    Some(_) => "Cmp",
4194                    None => "",
4195                };
4196                let level_str = match level {
4197                    Sl::Zero if gather.is_none() => "LevelZero",
4198                    Sl::Auto | Sl::Zero => "",
4199                    Sl::Exact(_) => "Level",
4200                    Sl::Bias(_) => "Bias",
4201                    Sl::Gradient { .. } => "Grad",
4202                };
4203
4204                self.write_expr(module, image, func_ctx)?;
4205                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
4206                self.write_expr(module, sampler, func_ctx)?;
4207                write!(self.out, ", ")?;
4208                self.write_texture_coordinates(
4209                    "float",
4210                    coordinate,
4211                    array_index,
4212                    None,
4213                    module,
4214                    func_ctx,
4215                )?;
4216
4217                if let Some(depth_ref) = depth_ref {
4218                    write!(self.out, ", ")?;
4219                    self.write_expr(module, depth_ref, func_ctx)?;
4220                }
4221
4222                match level {
4223                    Sl::Auto | Sl::Zero => {}
4224                    Sl::Exact(expr) => {
4225                        write!(self.out, ", ")?;
4226                        self.write_expr(module, expr, func_ctx)?;
4227                    }
4228                    Sl::Bias(expr) => {
4229                        write!(self.out, ", ")?;
4230                        self.write_expr(module, expr, func_ctx)?;
4231                    }
4232                    Sl::Gradient { x, y } => {
4233                        write!(self.out, ", ")?;
4234                        self.write_expr(module, x, func_ctx)?;
4235                        write!(self.out, ", ")?;
4236                        self.write_expr(module, y, func_ctx)?;
4237                    }
4238                }
4239
4240                if let Some(offset) = offset {
4241                    write!(self.out, ", ")?;
4242                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
4243                    self.write_const_expression(module, offset, func_ctx.expressions)?;
4244                    write!(self.out, ")")?;
4245                }
4246
4247                write!(self.out, ")")?;
4248            }
4249            Expression::ImageQuery { image, query } => {
4250                // use wrapped image query function
4251                if let TypeInner::Image {
4252                    dim,
4253                    arrayed,
4254                    class,
4255                } = *func_ctx.resolve_type(image, &module.types)
4256                {
4257                    let wrapped_image_query = WrappedImageQuery {
4258                        dim,
4259                        arrayed,
4260                        class,
4261                        query: query.into(),
4262                    };
4263
4264                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
4265                    write!(self.out, "(")?;
4266                    // Image always first param
4267                    self.write_expr(module, image, func_ctx)?;
4268                    if let crate::ImageQuery::Size { level: Some(level) } = query {
4269                        write!(self.out, ", ")?;
4270                        self.write_expr(module, level, func_ctx)?;
4271                    }
4272                    write!(self.out, ")")?;
4273                }
4274            }
4275            Expression::ImageLoad {
4276                image,
4277                coordinate,
4278                array_index,
4279                sample,
4280                level,
4281            } => self.write_image_load(
4282                &module,
4283                expr,
4284                func_ctx,
4285                image,
4286                coordinate,
4287                array_index,
4288                sample,
4289                level,
4290            )?,
4291            Expression::GlobalVariable(handle) => {
4292                let global_variable = &module.global_variables[handle];
4293                let ty = &module.types[global_variable.ty].inner;
4294
4295                // In the case of binding arrays of samplers, we need to not write anything
4296                // as the we are in the wrong position to fully write the expression.
4297                //
4298                // The entire writing is done by AccessIndex.
4299                let is_binding_array_of_samplers = match *ty {
4300                    TypeInner::BindingArray { base, .. } => {
4301                        let base_ty = &module.types[base].inner;
4302                        matches!(*base_ty, TypeInner::Sampler { .. })
4303                    }
4304                    _ => false,
4305                };
4306
4307                let is_storage_space =
4308                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
4309
4310                // Our external texture global variable has been expanded into multiple
4311                // global variables, one for each plane and the parameters buffer.
4312                // External textures can only ever be used as arguments to a function
4313                // call, and we know that an external texture argument to any function
4314                // will have been expanded to separate consecutive arguments for each
4315                // plane and the parameters buffer. Therefore we can simply emit each of
4316                // the expanded global variables in a consecutive comma-separated list.
4317                if let TypeInner::Image {
4318                    class: crate::ImageClass::External,
4319                    ..
4320                } = *ty
4321                {
4322                    let plane_names = [0, 1, 2].map(|i| {
4323                        &self.names[&NameKey::ExternalTextureGlobalVariable(
4324                            handle,
4325                            ExternalTextureNameKey::Plane(i),
4326                        )]
4327                    });
4328                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
4329                        handle,
4330                        ExternalTextureNameKey::Params,
4331                    )];
4332                    write!(
4333                        self.out,
4334                        "{}, {}, {}, {}",
4335                        plane_names[0], plane_names[1], plane_names[2], params_name
4336                    )?;
4337                } else if !is_binding_array_of_samplers && !is_storage_space {
4338                    let name = &self.names[&NameKey::GlobalVariable(handle)];
4339                    write!(self.out, "{name}")?;
4340                }
4341            }
4342            Expression::LocalVariable(handle) => {
4343                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
4344            }
4345            Expression::Load { pointer } => {
4346                match func_ctx
4347                    .resolve_type(pointer, &module.types)
4348                    .pointer_space()
4349                {
4350                    Some(crate::AddressSpace::Storage { .. }) => {
4351                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4352                        let result_ty = func_ctx.info[expr].ty.clone();
4353                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
4354                    }
4355                    _ => {
4356                        let mut close_paren = false;
4357
4358                        // We cast the value loaded to a native HLSL floatCx2
4359                        // in cases where it is of type:
4360                        //  - __matCx2 or
4361                        //  - a (possibly nested) array of __matCx2's
4362                        if let Some(MatrixType {
4363                            rows: crate::VectorSize::Bi,
4364                            ..
4365                        }) = get_inner_matrix_of_struct_array_member(
4366                            module, pointer, func_ctx, false,
4367                        )
4368                        .or_else(|| {
4369                            get_inner_matrix_of_global_uniform(module, pointer, func_ctx, false)
4370                        }) {
4371                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
4372                            let ptr_tr = resolved.pointer_base_type();
4373                            if let Some(ptr_ty) =
4374                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
4375                            {
4376                                resolved = ptr_ty;
4377                            }
4378
4379                            write!(self.out, "((")?;
4380                            if let TypeInner::Array { base, size, .. } = *resolved {
4381                                self.write_type(module, base)?;
4382                                self.write_array_size(module, base, size)?;
4383                            } else {
4384                                self.write_value_type(module, resolved)?;
4385                            }
4386                            write!(self.out, ")")?;
4387                            close_paren = true;
4388                        }
4389
4390                        self.write_expr(module, pointer, func_ctx)?;
4391
4392                        if close_paren {
4393                            write!(self.out, ")")?;
4394                        }
4395                    }
4396                }
4397            }
4398            Expression::Unary { op, expr } => {
4399                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
4400                let op_str = match op {
4401                    crate::UnaryOperator::Negate => {
4402                        match func_ctx.resolve_type(expr, &module.types).scalar() {
4403                            Some(Scalar::I32) => NEG_FUNCTION,
4404                            _ => "-",
4405                        }
4406                    }
4407                    crate::UnaryOperator::LogicalNot => "!",
4408                    crate::UnaryOperator::BitwiseNot => "~",
4409                };
4410                write!(self.out, "{op_str}(")?;
4411                self.write_expr(module, expr, func_ctx)?;
4412                write!(self.out, ")")?;
4413            }
4414            Expression::As {
4415                expr,
4416                kind,
4417                convert,
4418            } => {
4419                let inner = func_ctx.resolve_type(expr, &module.types);
4420                if inner.scalar_kind() == Some(ScalarKind::Float)
4421                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
4422                    && convert.is_some()
4423                    && matches!(convert, Some(4) | Some(8))
4424                {
4425                    // Use helper functions for float to int casts in order to
4426                    // avoid undefined behaviour when value is out of range for
4427                    // the target type.
4428                    let fun_name = match (kind, convert) {
4429                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
4430                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
4431                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
4432                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
4433                        _ => unreachable!(),
4434                    };
4435                    write!(self.out, "{fun_name}(")?;
4436                    self.write_expr(module, expr, func_ctx)?;
4437                    write!(self.out, ")")?;
4438                } else {
4439                    let close_paren = match convert {
4440                        Some(dst_width) => {
4441                            let scalar = Scalar {
4442                                kind,
4443                                width: dst_width,
4444                            };
4445                            match *inner {
4446                                TypeInner::Vector { size, .. } => {
4447                                    write!(
4448                                        self.out,
4449                                        "{}{}(",
4450                                        scalar.to_hlsl_str()?,
4451                                        common::vector_size_str(size)
4452                                    )?;
4453                                }
4454                                TypeInner::Scalar(_) => {
4455                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
4456                                }
4457                                TypeInner::Matrix { columns, rows, .. } => {
4458                                    write!(
4459                                        self.out,
4460                                        "{}{}x{}(",
4461                                        scalar.to_hlsl_str()?,
4462                                        common::vector_size_str(columns),
4463                                        common::vector_size_str(rows)
4464                                    )?;
4465                                }
4466                                _ => {
4467                                    return Err(Error::Unimplemented(format!(
4468                                        "write_expr expression::as {inner:?}"
4469                                    )));
4470                                }
4471                            };
4472                            true
4473                        }
4474                        None => {
4475                            if inner.scalar_width() == Some(8) {
4476                                false
4477                            } else if inner.scalar_width() == Some(2) {
4478                                // HLSL's asint()/asuint() only work on 32-bit types.
4479                                // For 16-bit bitcasts, use type constructor instead.
4480                                let dst_scalar = Scalar { kind, width: 2 };
4481                                match *inner {
4482                                    TypeInner::Vector { size, .. } => {
4483                                        write!(
4484                                            self.out,
4485                                            "{}{}(",
4486                                            dst_scalar.to_hlsl_str()?,
4487                                            common::vector_size_str(size)
4488                                        )?;
4489                                    }
4490                                    _ => {
4491                                        write!(self.out, "{}(", dst_scalar.to_hlsl_str()?)?;
4492                                    }
4493                                };
4494                                true
4495                            } else {
4496                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
4497                                true
4498                            }
4499                        }
4500                    };
4501                    self.write_expr(module, expr, func_ctx)?;
4502                    if close_paren {
4503                        write!(self.out, ")")?;
4504                    }
4505                }
4506            }
4507            Expression::Math {
4508                fun,
4509                arg,
4510                arg1,
4511                arg2,
4512                arg3,
4513            } => {
4514                return self.write_math_expression(module, fun, arg, arg1, arg2, arg3, func_ctx);
4515            }
4516            Expression::Swizzle {
4517                size,
4518                vector,
4519                pattern,
4520            } => {
4521                self.write_expr(module, vector, func_ctx)?;
4522                write!(self.out, ".")?;
4523                for &sc in pattern[..size as usize].iter() {
4524                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4525                }
4526            }
4527            Expression::ArrayLength(expr) => {
4528                let var_handle = match func_ctx.expressions[expr] {
4529                    Expression::AccessIndex { base, index: _ } => {
4530                        match func_ctx.expressions[base] {
4531                            Expression::GlobalVariable(handle) => handle,
4532                            _ => unreachable!(),
4533                        }
4534                    }
4535                    Expression::GlobalVariable(handle) => handle,
4536                    _ => unreachable!(),
4537                };
4538
4539                let var = &module.global_variables[var_handle];
4540                let (offset, stride) = match module.types[var.ty].inner {
4541                    TypeInner::Array { stride, .. } => (0, stride),
4542                    TypeInner::Struct { ref members, .. } => {
4543                        let last = members.last().unwrap();
4544                        let stride = match module.types[last.ty].inner {
4545                            TypeInner::Array { stride, .. } => stride,
4546                            _ => unreachable!(),
4547                        };
4548                        (last.offset, stride)
4549                    }
4550                    _ => unreachable!(),
4551                };
4552
4553                let storage_access = match var.space {
4554                    crate::AddressSpace::Storage { access } => access,
4555                    _ => crate::StorageAccess::default(),
4556                };
4557                let wrapped_array_length = WrappedArrayLength {
4558                    writable: storage_access.contains(crate::StorageAccess::STORE),
4559                };
4560
4561                write!(self.out, "((")?;
4562                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4563                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4564                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4565            }
4566            Expression::Derivative { axis, ctrl, expr } => {
4567                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4568                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4569                    let tail = match ctrl {
4570                        Ctrl::Coarse => "coarse",
4571                        Ctrl::Fine => "fine",
4572                        Ctrl::None => unreachable!(),
4573                    };
4574                    write!(self.out, "abs(ddx_{tail}(")?;
4575                    self.write_expr(module, expr, func_ctx)?;
4576                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4577                    self.write_expr(module, expr, func_ctx)?;
4578                    write!(self.out, "))")?
4579                } else {
4580                    let fun_str = match (axis, ctrl) {
4581                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4582                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4583                        (Axis::X, Ctrl::None) => "ddx",
4584                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4585                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4586                        (Axis::Y, Ctrl::None) => "ddy",
4587                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4588                        (Axis::Width, Ctrl::None) => "fwidth",
4589                    };
4590                    write!(self.out, "{fun_str}(")?;
4591                    self.write_expr(module, expr, func_ctx)?;
4592                    write!(self.out, ")")?
4593                }
4594            }
4595            Expression::Relational { fun, argument } => {
4596                use crate::RelationalFunction as Rf;
4597
4598                let fun_str = match fun {
4599                    Rf::All => "all",
4600                    Rf::Any => "any",
4601                    Rf::IsNan => "isnan",
4602                    Rf::IsInf => "isinf",
4603                };
4604                write!(self.out, "{fun_str}(")?;
4605                self.write_expr(module, argument, func_ctx)?;
4606                write!(self.out, ")")?
4607            }
4608            Expression::Select {
4609                condition,
4610                accept,
4611                reject,
4612            } => {
4613                write!(self.out, "(")?;
4614                self.write_expr(module, condition, func_ctx)?;
4615                write!(self.out, " ? ")?;
4616                self.write_expr(module, accept, func_ctx)?;
4617                write!(self.out, " : ")?;
4618                self.write_expr(module, reject, func_ctx)?;
4619                write!(self.out, ")")?
4620            }
4621            Expression::RayQueryGetIntersection { query, committed } => {
4622                // For reasoning, see write_stmt
4623                let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4624                    unreachable!()
4625                };
4626
4627                let tracker_expr_name = format!(
4628                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4629                    self.names[&func_ctx.name_key(query_var)]
4630                );
4631
4632                if committed {
4633                    write!(self.out, "GetCommittedIntersection(")?;
4634                    self.write_expr(module, query, func_ctx)?;
4635                    write!(self.out, ", {tracker_expr_name})")?;
4636                } else {
4637                    write!(self.out, "GetCandidateIntersection(")?;
4638                    self.write_expr(module, query, func_ctx)?;
4639                    write!(self.out, ", {tracker_expr_name})")?;
4640                }
4641            }
4642            // Not supported yet
4643            Expression::RayQueryVertexPositions { .. }
4644            | Expression::CooperativeLoad { .. }
4645            | Expression::CooperativeMultiplyAdd { .. } => {
4646                unreachable!()
4647            }
4648            // Nothing to do here, since call expression already cached
4649            Expression::CallResult(_)
4650            | Expression::AtomicResult { .. }
4651            | Expression::WorkGroupUniformLoadResult { .. }
4652            | Expression::RayQueryProceedResult
4653            | Expression::SubgroupBallotResult
4654            | Expression::SubgroupOperationResult { .. } => {}
4655        }
4656
4657        if !closing_bracket.is_empty() {
4658            write!(self.out, "{closing_bracket}")?;
4659        }
4660        Ok(())
4661    }
4662
4663    #[allow(clippy::too_many_arguments)]
4664    fn write_image_load(
4665        &mut self,
4666        module: &&Module,
4667        expr: Handle<crate::Expression>,
4668        func_ctx: &back::FunctionCtx,
4669        image: Handle<crate::Expression>,
4670        coordinate: Handle<crate::Expression>,
4671        array_index: Option<Handle<crate::Expression>>,
4672        sample: Option<Handle<crate::Expression>>,
4673        level: Option<Handle<crate::Expression>>,
4674    ) -> Result<(), Error> {
4675        let mut wrapping_type = None;
4676        match *func_ctx.resolve_type(image, &module.types) {
4677            TypeInner::Image {
4678                class: crate::ImageClass::External,
4679                ..
4680            } => {
4681                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4682                self.write_expr(module, image, func_ctx)?;
4683                write!(self.out, ", ")?;
4684                self.write_expr(module, coordinate, func_ctx)?;
4685                write!(self.out, ")")?;
4686                return Ok(());
4687            }
4688            TypeInner::Image {
4689                class: crate::ImageClass::Storage { format, .. },
4690                ..
4691            } => {
4692                if format.single_component() {
4693                    wrapping_type = Some(Scalar::from(format));
4694                }
4695            }
4696            _ => {}
4697        }
4698        if let Some(scalar) = wrapping_type {
4699            write!(
4700                self.out,
4701                "{}{}(",
4702                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4703                scalar.to_hlsl_str()?
4704            )?;
4705        }
4706        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4707        self.write_expr(module, image, func_ctx)?;
4708        write!(self.out, ".Load(")?;
4709
4710        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4711
4712        if let Some(sample) = sample {
4713            write!(self.out, ", ")?;
4714            self.write_expr(module, sample, func_ctx)?;
4715        }
4716
4717        // close bracket for Load function
4718        write!(self.out, ")")?;
4719
4720        if wrapping_type.is_some() {
4721            write!(self.out, ")")?;
4722        }
4723
4724        // return x component if return type is scalar
4725        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4726            write!(self.out, ".x")?;
4727        }
4728        Ok(())
4729    }
4730
4731    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4732    /// can be generated later.
4733    fn sampler_binding_array_info_from_expression(
4734        &mut self,
4735        module: &Module,
4736        func_ctx: &back::FunctionCtx<'_>,
4737        base: Handle<crate::Expression>,
4738        resolved: &TypeInner,
4739    ) -> Option<BindingArraySamplerInfo> {
4740        if let TypeInner::BindingArray {
4741            base: base_ty_handle,
4742            ..
4743        } = *resolved
4744        {
4745            let base_ty = &module.types[base_ty_handle].inner;
4746            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4747                let base = &func_ctx.expressions[base];
4748
4749                if let crate::Expression::GlobalVariable(handle) = *base {
4750                    let variable = &module.global_variables[handle];
4751
4752                    let sampler_heap_name = match comparison {
4753                        true => COMPARISON_SAMPLER_HEAP_VAR,
4754                        false => SAMPLER_HEAP_VAR,
4755                    };
4756
4757                    return Some(BindingArraySamplerInfo {
4758                        sampler_heap_name,
4759                        sampler_index_buffer_name: self
4760                            .wrapped
4761                            .sampler_index_buffers
4762                            .get(&super::SamplerIndexBufferKey {
4763                                group: variable.binding.unwrap().group,
4764                            })
4765                            .unwrap()
4766                            .clone(),
4767                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4768                            .clone(),
4769                    });
4770                }
4771            }
4772        }
4773
4774        None
4775    }
4776
4777    fn write_named_expr(
4778        &mut self,
4779        module: &Module,
4780        handle: Handle<crate::Expression>,
4781        name: String,
4782        // The expression which is being named.
4783        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4784        expr: Handle<crate::Expression>,
4785        func_ctx: &back::FunctionCtx,
4786    ) -> BackendResult {
4787        if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4788            let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4789            if ty_inner.is_atomic_pointer(&module.types) {
4790                let pointer_space = ty_inner.pointer_space().unwrap();
4791                self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4792                write!(self.out, " {name}; ")?;
4793                match pointer_space {
4794                    crate::AddressSpace::WorkGroup => {
4795                        write!(self.out, "InterlockedOr(")?;
4796                        self.write_expr(module, pointer, func_ctx)?;
4797                    }
4798                    crate::AddressSpace::Storage { .. } => {
4799                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4800                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4801                        write!(self.out, "{var_name}.InterlockedOr(")?;
4802                        let chain = mem::take(&mut self.temp_access_chain);
4803                        self.write_storage_address(module, &chain, func_ctx)?;
4804                        self.temp_access_chain = chain;
4805                    }
4806                    _ => unreachable!(),
4807                }
4808                writeln!(self.out, ", 0, {name});")?;
4809                self.named_expressions.insert(expr, name);
4810                return Ok(());
4811            }
4812        }
4813        match func_ctx.info[expr].ty {
4814            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4815                TypeInner::Struct { .. } => {
4816                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4817                    write!(self.out, "{ty_name}")?;
4818                }
4819                _ => {
4820                    self.write_type(module, ty_handle)?;
4821                }
4822            },
4823            proc::TypeResolution::Value(ref inner) => {
4824                self.write_value_type(module, inner)?;
4825            }
4826        }
4827
4828        let resolved = func_ctx.resolve_type(expr, &module.types);
4829
4830        write!(self.out, " {name}")?;
4831        // If rhs is a array type, we should write array size
4832        if let TypeInner::Array { base, size, .. } = *resolved {
4833            self.write_array_size(module, base, size)?;
4834        }
4835        write!(self.out, " = ")?;
4836        self.write_expr(module, handle, func_ctx)?;
4837        writeln!(self.out, ";")?;
4838        self.named_expressions.insert(expr, name);
4839
4840        Ok(())
4841    }
4842
4843    /// Helper function that write default zero initialization
4844    pub(super) fn write_default_init(
4845        &mut self,
4846        module: &Module,
4847        ty: Handle<crate::Type>,
4848    ) -> BackendResult {
4849        write!(self.out, "(")?;
4850        self.write_type(module, ty)?;
4851        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4852            self.write_array_size(module, base, size)?;
4853        }
4854        write!(self.out, ")0")?;
4855        Ok(())
4856    }
4857
4858    pub(super) fn write_control_barrier(
4859        &mut self,
4860        barrier: crate::Barrier,
4861        level: back::Level,
4862    ) -> BackendResult {
4863        if barrier.contains(crate::Barrier::STORAGE) {
4864            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4865        }
4866        if barrier.contains(crate::Barrier::WORK_GROUP) {
4867            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4868        }
4869        if barrier.contains(crate::Barrier::SUB_GROUP) {
4870            // Does not exist in DirectX
4871        }
4872        if barrier.contains(crate::Barrier::TEXTURE) {
4873            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4874        }
4875        Ok(())
4876    }
4877
4878    fn write_memory_barrier(
4879        &mut self,
4880        barrier: crate::Barrier,
4881        level: back::Level,
4882    ) -> BackendResult {
4883        if barrier.contains(crate::Barrier::STORAGE) {
4884            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4885        }
4886        if barrier.contains(crate::Barrier::WORK_GROUP) {
4887            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4888        }
4889        if barrier.contains(crate::Barrier::SUB_GROUP) {
4890            // Does not exist in DirectX
4891        }
4892        if barrier.contains(crate::Barrier::TEXTURE) {
4893            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4894        }
4895        Ok(())
4896    }
4897
4898    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4899    fn emit_hlsl_atomic_tail(
4900        &mut self,
4901        module: &Module,
4902        func_ctx: &back::FunctionCtx<'_>,
4903        fun: &crate::AtomicFunction,
4904        compare_expr: Option<Handle<crate::Expression>>,
4905        value: Handle<crate::Expression>,
4906        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4907    ) -> BackendResult {
4908        if let Some(cmp) = compare_expr {
4909            write!(self.out, ", ")?;
4910            self.write_expr(module, cmp, func_ctx)?;
4911        }
4912        write!(self.out, ", ")?;
4913        if let crate::AtomicFunction::Subtract = *fun {
4914            // we just wrote `InterlockedAdd`, so negate the argument
4915            write!(self.out, "-")?;
4916        }
4917        self.write_expr(module, value, func_ctx)?;
4918        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4919            write!(self.out, ", ")?;
4920            if compare_expr.is_some() {
4921                write!(self.out, "{res_name}.old_value")?;
4922            } else {
4923                write!(self.out, "{res_name}")?;
4924            }
4925        }
4926        writeln!(self.out, ");")?;
4927        Ok(())
4928    }
4929}
4930
/// Dimensions and scalar width of a matrix type.
///
/// Built from the `columns`, `rows` and `scalar` fields of a [`TypeInner::Matrix`].
pub(super) struct MatrixType {
    /// Number of columns (column vectors) in the matrix.
    pub(super) columns: crate::VectorSize,
    /// Number of rows, i.e. the size of each column vector.
    pub(super) rows: crate::VectorSize,
    /// Width of the matrix's scalar component, taken from `scalar.width`.
    pub(super) width: crate::Bytes,
}
4936
4937pub(super) fn get_inner_matrix_data(
4938    module: &Module,
4939    handle: Handle<crate::Type>,
4940) -> Option<MatrixType> {
4941    match module.types[handle].inner {
4942        TypeInner::Matrix {
4943            columns,
4944            rows,
4945            scalar,
4946        } => Some(MatrixType {
4947            columns,
4948            rows,
4949            width: scalar.width,
4950        }),
4951        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4952        _ => None,
4953    }
4954}
4955
4956/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4957/// returns a tuple of the matrix, the column (vector) index (if present), and
4958/// the row (scalar) index (if present).
4959fn find_matrix_in_access_chain(
4960    module: &Module,
4961    base: Handle<crate::Expression>,
4962    func_ctx: &back::FunctionCtx<'_>,
4963) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4964    let mut current_base = base;
4965    let mut vector = None;
4966    let mut scalar = None;
4967    loop {
4968        let resolved_tr = func_ctx
4969            .resolve_type(current_base, &module.types)
4970            .pointer_base_type();
4971        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4972
4973        match *resolved {
4974            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4975            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4976            _ => return None,
4977        }
4978
4979        let index;
4980        (current_base, index) = match func_ctx.expressions[current_base] {
4981            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4982            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4983            _ => return None,
4984        };
4985
4986        match *resolved {
4987            TypeInner::Scalar(_) => scalar = Some(index),
4988            TypeInner::Vector { .. } => vector = Some(index),
4989            _ => unreachable!(),
4990        }
4991    }
4992}
4993
4994/// Returns the matrix data if the access chain starting at `base`:
4995/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4996/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4997/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4998pub(super) fn get_inner_matrix_of_struct_array_member(
4999    module: &Module,
5000    base: Handle<crate::Expression>,
5001    func_ctx: &back::FunctionCtx<'_>,
5002    direct: bool,
5003) -> Option<MatrixType> {
5004    let mut mat_data = None;
5005    let mut array_base = None;
5006
5007    let mut current_base = base;
5008    loop {
5009        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5010        if let TypeInner::Pointer { base, .. } = *resolved {
5011            resolved = &module.types[base].inner;
5012        };
5013
5014        match *resolved {
5015            TypeInner::Matrix {
5016                columns,
5017                rows,
5018                scalar,
5019            } => {
5020                mat_data = Some(MatrixType {
5021                    columns,
5022                    rows,
5023                    width: scalar.width,
5024                })
5025            }
5026            TypeInner::Array { base, .. } => {
5027                array_base = Some(base);
5028            }
5029            TypeInner::Struct { .. } => {
5030                if let Some(array_base) = array_base {
5031                    if direct {
5032                        return mat_data;
5033                    } else {
5034                        return get_inner_matrix_data(module, array_base);
5035                    }
5036                }
5037
5038                break;
5039            }
5040            _ => break,
5041        }
5042
5043        current_base = match func_ctx.expressions[current_base] {
5044            crate::Expression::Access { base, .. } => base,
5045            crate::Expression::AccessIndex { base, .. } => base,
5046            _ => break,
5047        };
5048    }
5049    None
5050}
5051
5052/// Returns the matrix data if the access chain starting at `base`:
5053/// - starts with an expression with resolved type of [`TypeInner::Matrix`], or
5054/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] if `direct = false`
5055/// - and ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
5056fn get_inner_matrix_of_global_uniform(
5057    module: &Module,
5058    base: Handle<crate::Expression>,
5059    func_ctx: &back::FunctionCtx<'_>,
5060    direct: bool,
5061) -> Option<MatrixType> {
5062    let mut mat_data = None;
5063    let mut array_base = None;
5064
5065    let mut current_base = base;
5066    loop {
5067        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5068        if let TypeInner::Pointer { base, .. } = *resolved {
5069            resolved = &module.types[base].inner;
5070        };
5071
5072        match *resolved {
5073            TypeInner::Matrix {
5074                columns,
5075                rows,
5076                scalar,
5077            } => {
5078                mat_data = Some(MatrixType {
5079                    columns,
5080                    rows,
5081                    width: scalar.width,
5082                })
5083            }
5084            TypeInner::Array { base, .. } => {
5085                if !direct {
5086                    array_base = Some(base);
5087                }
5088            }
5089            _ => break,
5090        }
5091
5092        current_base = match func_ctx.expressions[current_base] {
5093            crate::Expression::Access { base, .. } => base,
5094            crate::Expression::AccessIndex { base, .. } => base,
5095            crate::Expression::GlobalVariable(handle)
5096                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5097            {
5098                return mat_data.or_else(|| {
5099                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5100                })
5101            }
5102            _ => break,
5103        };
5104    }
5105    None
5106}