naga/back/hlsl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::{self, Write as _},
8    mem,
9};
10
11use super::{
12    help,
13    help::{
14        WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
15        WrappedZeroValue,
16    },
17    storage::StoreValue,
18    BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
19};
20use crate::{
21    back::{self, get_entry_points, Baked},
22    common,
23    proc::{self, index, ExternalTextureNameKey, NameKey},
24    valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
25};
26
/// Semantic prefix used for user-defined location varyings; the location
/// number is appended directly (see `write_semantic`).
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the struct type declared for the special constants buffer.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer` variable holding the special constants.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` member of the special constants struct.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` member of the special constants struct.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` member of the special constants struct.
const SPECIAL_OTHER: &str = "other";

// NOTE(review): the constants below are not used in this part of the file.
// Judging by their values they are names of helper functions and variables
// emitted into the generated HLSL -- confirm against their use sites.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";
pub(crate) const SAMPLE_EXTERNAL_TEXTURE_FUNCTION: &str = "nagaSampleExternalTexture";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_";
/// Prefix for variables in a naga statement
pub(crate) const INTERNAL_PREFIX: &str = "naga_";
55
/// An index given either by an expression or by a fixed number.
///
/// NOTE(review): the consumers of this type are outside this excerpt.
enum Index {
    /// Index computed by the referenced expression.
    Expression(Handle<crate::Expression>),
    /// Index known as a plain `u32`.
    Static(u32),
}
60
/// One member of a flattened entry point interface struct.
pub(super) struct EpStructMember {
    /// Name of the member as written in the generated struct.
    pub(super) name: String,
    /// Type of the member.
    pub(super) ty: Handle<crate::Type>,
    // technically, this should always be `Some`
    // (we `debug_assert!` this in `write_interface_struct`)
    pub(super) binding: Option<crate::Binding>,
    /// Position of the member in the order it was gathered, i.e. before
    /// `write_interface_struct` sorts the members by binding. Input structs
    /// are sorted back into this order after being written.
    pub(super) index: u32,
}
69
/// Information describing a generated struct that wraps part of an entry
/// point's interface (for example, all of its flattened arguments).
pub(super) struct EntryPointBinding {
    /// Name of the fake EP argument that contains the struct
    /// with all the flattened input data.
    pub(super) arg_name: String,
    /// Generated structure name
    pub(super) ty_name: String,
    /// Members of generated structure
    pub(super) members: Vec<EpStructMember>,
    /// Name of the struct member carrying the local invocation index, if the
    /// struct has one. `write_interface_struct` sets this either from a
    /// member bound to `BuiltIn::LocalInvocationIndex`, or from a member it
    /// adds itself when `BuiltIn::SubgroupId` is used without one.
    pub(super) local_invocation_index_name: Option<String>,
}
82
/// The generated interface structs of a single entry point.
pub(super) struct EntryPointInterface {
    /// If `Some`, the input of an entry point is gathered in a special
    /// struct whose members are written sorted by binding.
    /// The `EntryPointBinding::members` array itself is sorted by index
    /// (the original order), so that we can walk it in
    /// `write_ep_arguments_initialization`.
    pub(crate) input: Option<EntryPointBinding>,
    /// If `Some`, the output of an entry point is flattened.
    /// The `EntryPointBinding::members` array is sorted by binding,
    /// so that we can walk it in the `Statement::Return` handler.
    pub(crate) output: Option<EntryPointBinding>,
    // NOTE(review): the three mesh fields below are populated outside this
    // excerpt; presumably they describe the vertex, primitive and index
    // outputs of a mesh shader entry point -- confirm.
    pub(crate) mesh_vertices: Option<EntryPointBinding>,
    pub(crate) mesh_primitives: Option<EntryPointBinding>,
    pub(crate) mesh_indices: Option<EntryPointBinding>,
}
97
/// Sort key for the members of an interface struct.
///
/// The derived ordering follows the variant order: locations first (in
/// ascending location number), then built-ins, then members without a
/// binding. Do not reorder the variants without intending to change the
/// layout of generated structs.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    /// Member bound to a user-defined location.
    Location(u32),
    /// Member bound to a built-in.
    BuiltIn(crate::BuiltIn),
    /// Member without any binding.
    Other,
}
104
105impl InterfaceKey {
106    const fn new(binding: Option<&crate::Binding>) -> Self {
107        match binding {
108            Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
109            Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
110            None => Self::Other,
111        }
112    }
113}
114
/// Identifies which part of an entry point's interface is being written.
///
/// Paired with a `ShaderStage` when writing semantics and interface structs.
#[derive(Copy, Clone, PartialEq)]
pub(super) enum Io {
    /// Entry point inputs (arguments).
    Input,
    /// Entry point outputs (result).
    Output,
    // NOTE(review): presumably the per-vertex outputs of a mesh shader -- confirm.
    MeshVertices,
    // NOTE(review): presumably the per-primitive outputs of a mesh shader -- confirm.
    MeshPrimitives,
}
122
/// Argument list for nested entry points
pub(super) struct NestedEntryPointArgs {
    /// Arguments literally declared by the user
    pub user_args: Vec<String>,
    // NOTE(review): presumably the name of the task payload argument, when the
    // entry point has one -- confirm against the code that builds this struct.
    pub task_payload: Option<String>,
    // NOTE(review): presumably the name of the local invocation index
    // argument -- confirm against the code that builds this struct.
    pub local_invocation_index: String,
}
130
131const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
132    let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
133        return false;
134    };
135    matches!(
136        builtin,
137        crate::BuiltIn::SubgroupSize
138            | crate::BuiltIn::SubgroupInvocationId
139            | crate::BuiltIn::NumSubgroups
140            | crate::BuiltIn::SubgroupId
141    )
142}
143
/// Information for how to generate a `binding_array<sampler>` access.
struct BindingArraySamplerInfo {
    /// Variable name of the sampler heap
    // NOTE(review): presumably `SAMPLER_HEAP_VAR` or
    // `COMPARISON_SAMPLER_HEAP_VAR` -- confirm where this struct is built.
    sampler_heap_name: &'static str,
    /// Variable name of the sampler index buffer
    sampler_index_buffer_name: String,
    /// Variable name of the base index _into_ the sampler index buffer
    binding_array_base_index_name: String,
}
153
154impl<'a, W: fmt::Write> super::Writer<'a, W> {
155    pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
156        Self {
157            out,
158            names: crate::FastHashMap::default(),
159            namer: proc::Namer::default(),
160            options,
161            pipeline_options,
162            entry_point_io: crate::FastHashMap::default(),
163            named_expressions: crate::NamedExpressions::default(),
164            wrapped: super::Wrapped::default(),
165            written_committed_intersection: false,
166            written_candidate_intersection: false,
167            continue_ctx: back::continue_forward::ContinueCtx::default(),
168            temp_access_chain: Vec::new(),
169            need_bake_expressions: Default::default(),
170            function_task_payload_var: Default::default(),
171        }
172    }
173
174    fn reset(&mut self, module: &Module) {
175        self.names.clear();
176        self.namer.reset(
177            module,
178            &super::keywords::RESERVED_SET,
179            proc::KeywordSet::empty(),
180            &super::keywords::RESERVED_CASE_INSENSITIVE_SET,
181            super::keywords::RESERVED_PREFIXES,
182            &mut self.names,
183        );
184        self.entry_point_io.clear();
185        self.named_expressions.clear();
186        self.wrapped.clear();
187        self.written_committed_intersection = false;
188        self.written_candidate_intersection = false;
189        self.continue_ctx.clear();
190        self.need_bake_expressions.clear();
191        self.function_task_payload_var.clear();
192    }
193
194    /// Generates statements to be inserted immediately before and at the very
195    /// start of the body of each loop, to defeat infinite loop reasoning.
196    /// The 0th item of the returned tuple should be inserted immediately prior
197    /// to the loop and the 1st item should be inserted at the very start of
198    /// the loop body.
199    ///
200    /// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
201    fn gen_force_bounded_loop_statements(
202        &mut self,
203        level: back::Level,
204    ) -> Option<(String, String)> {
205        if !self.options.force_loop_bounding {
206            return None;
207        }
208
209        let loop_bound_name = self.namer.call("loop_bound");
210        let max = u32::MAX;
211        // Count down from u32::MAX rather than up from 0 to avoid hang on
212        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
213        let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
214        let level = level.next();
215        let break_and_inc = format!(
216            "{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
217{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
218        );
219
220        Some((decl, break_and_inc))
221    }
222
223    /// Helper method used to find which expressions of a given function require baking
224    ///
225    /// # Notes
226    /// Clears `need_bake_expressions` set before adding to it
227    fn update_expressions_to_bake(
228        &mut self,
229        module: &Module,
230        func: &crate::Function,
231        info: &valid::FunctionInfo,
232    ) {
233        use crate::Expression;
234        self.need_bake_expressions.clear();
235        for (exp_handle, expr) in func.expressions.iter() {
236            let expr_info = &info[exp_handle];
237            let min_ref_count = func.expressions[exp_handle].bake_ref_count();
238            if min_ref_count <= expr_info.ref_count {
239                self.need_bake_expressions.insert(exp_handle);
240            }
241            if let Expression::Load { pointer } = *expr {
242                if info[pointer]
243                    .ty
244                    .inner_with(&module.types)
245                    .is_atomic_pointer(&module.types)
246                {
247                    self.need_bake_expressions.insert(exp_handle);
248                }
249            }
250
251            if let Expression::Math { fun, arg, arg1, .. } = *expr {
252                match fun {
253                    crate::MathFunction::Asinh
254                    | crate::MathFunction::Acosh
255                    | crate::MathFunction::Atanh
256                    | crate::MathFunction::Unpack2x16float
257                    | crate::MathFunction::Unpack2x16snorm
258                    | crate::MathFunction::Unpack2x16unorm
259                    | crate::MathFunction::Unpack4x8snorm
260                    | crate::MathFunction::Unpack4x8unorm
261                    | crate::MathFunction::Unpack4xI8
262                    | crate::MathFunction::Unpack4xU8
263                    | crate::MathFunction::Pack2x16float
264                    | crate::MathFunction::Pack2x16snorm
265                    | crate::MathFunction::Pack2x16unorm
266                    | crate::MathFunction::Pack4x8snorm
267                    | crate::MathFunction::Pack4x8unorm
268                    | crate::MathFunction::Pack4xI8
269                    | crate::MathFunction::Pack4xU8
270                    | crate::MathFunction::Pack4xI8Clamp
271                    | crate::MathFunction::Pack4xU8Clamp => {
272                        self.need_bake_expressions.insert(arg);
273                    }
274                    crate::MathFunction::CountLeadingZeros => {
275                        let inner = info[exp_handle].ty.inner_with(&module.types);
276                        if let Some(ScalarKind::Sint) = inner.scalar_kind() {
277                            self.need_bake_expressions.insert(arg);
278                        }
279                    }
280                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
281                        self.need_bake_expressions.insert(arg);
282                        self.need_bake_expressions.insert(arg1.unwrap());
283                    }
284                    _ => {}
285                }
286            }
287
288            if let Expression::Derivative { axis, ctrl, expr } = *expr {
289                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
290                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
291                    self.need_bake_expressions.insert(expr);
292                }
293            }
294
295            if let Expression::GlobalVariable(_) = *expr {
296                let inner = info[exp_handle].ty.inner_with(&module.types);
297
298                if let TypeInner::Sampler { .. } = *inner {
299                    self.need_bake_expressions.insert(exp_handle);
300                }
301            }
302        }
303        for statement in func.body.iter() {
304            match *statement {
305                crate::Statement::SubgroupCollectiveOperation {
306                    op: _,
307                    collective_op: crate::CollectiveOperation::InclusiveScan,
308                    argument,
309                    result: _,
310                } => {
311                    self.need_bake_expressions.insert(argument);
312                }
313                crate::Statement::Atomic {
314                    fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
315                    ..
316                } => {
317                    self.need_bake_expressions.insert(cmp);
318                }
319                _ => {}
320            }
321        }
322    }
323
324    pub fn write(
325        &mut self,
326        module: &Module,
327        module_info: &valid::ModuleInfo,
328        fragment_entry_point: Option<&FragmentEntryPoint<'_>>,
329    ) -> Result<super::ReflectionInfo, Error> {
330        self.reset(module);
331
332        if module.uses_mesh_shaders() && self.options.shader_model < ShaderModel::V6_5 {
333            return Err(Error::ShaderModelTooLow(
334                "mesh shaders".to_string(),
335                ShaderModel::V6_5,
336            ));
337        }
338
339        // Write special constants, if needed
340        if let Some(ref bt) = self.options.special_constants_binding {
341            writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
342            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
343            writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
344            writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
345            writeln!(self.out, "}};")?;
346            write!(
347                self.out,
348                "ConstantBuffer<{}> {}: register(b{}",
349                SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
350            )?;
351            if bt.space != 0 {
352                write!(self.out, ", space{}", bt.space)?;
353            }
354            writeln!(self.out, ");")?;
355
356            // Extra newline for readability
357            writeln!(self.out)?;
358        }
359
360        for (group, bt) in self.options.dynamic_storage_buffer_offsets_targets.iter() {
361            writeln!(self.out, "struct __dynamic_buffer_offsetsTy{group} {{")?;
362            for i in 0..bt.size {
363                writeln!(self.out, "{}uint _{};", back::INDENT, i)?;
364            }
365            writeln!(self.out, "}};")?;
366            writeln!(
367                self.out,
368                "ConstantBuffer<__dynamic_buffer_offsetsTy{}> __dynamic_buffer_offsets{}: register(b{}, space{});",
369                group, group, bt.register, bt.space
370            )?;
371
372            // Extra newline for readability
373            writeln!(self.out)?;
374        }
375
376        // Save all entry point output types
377        let ep_results = module
378            .entry_points
379            .iter()
380            .map(|ep| (ep.stage, ep.function.result.clone()))
381            .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
382
383        self.write_all_mat_cx2_typedefs_and_functions(module)?;
384
385        // Write all structs
386        for (handle, ty) in module.types.iter() {
387            if let TypeInner::Struct { ref members, span } = ty.inner {
388                if module.types[members.last().unwrap().ty]
389                    .inner
390                    .is_dynamically_sized(&module.types)
391                {
392                    // unsized arrays can only be in storage buffers,
393                    // for which we use `ByteAddressBuffer` anyway.
394                    continue;
395                }
396
397                let ep_result = ep_results.iter().find(|e| {
398                    if let Some(ref result) = e.1 {
399                        result.ty == handle
400                    } else {
401                        false
402                    }
403                });
404
405                self.write_struct(
406                    module,
407                    handle,
408                    members,
409                    span,
410                    ep_result.map(|r| (r.0, Io::Output)),
411                )?;
412                writeln!(self.out)?;
413            }
414        }
415
416        self.write_special_functions(module)?;
417
418        self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
419        self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
420
421        // Write all named constants
422        let mut constants = module
423            .constants
424            .iter()
425            .filter(|&(_, c)| c.name.is_some())
426            .peekable();
427        while let Some((handle, _)) = constants.next() {
428            self.write_global_constant(module, handle)?;
429            // Add extra newline for readability on last iteration
430            if constants.peek().is_none() {
431                writeln!(self.out)?;
432            }
433        }
434
435        // Write all globals
436        for (global, _) in module.global_variables.iter() {
437            self.write_global(module, global)?;
438        }
439
440        if !module.global_variables.is_empty() {
441            // Add extra newline for readability
442            writeln!(self.out)?;
443        }
444
445        let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
446            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
447
448        // Write all entry points wrapped structs
449        for index in ep_range.clone() {
450            let ep = &module.entry_points[index];
451            let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
452            let ep_io = self.write_ep_interface(module, ep, &ep_name, fragment_entry_point)?;
453            self.entry_point_io.insert(index, ep_io);
454        }
455
456        // Write all regular functions
457        for (handle, function) in module.functions.iter() {
458            let info = &module_info[handle];
459
460            // Check if all of the globals are accessible
461            if !self.options.fake_missing_bindings {
462                if let Some((var_handle, _)) =
463                    module
464                        .global_variables
465                        .iter()
466                        .find(|&(var_handle, var)| match var.binding {
467                            Some(ref binding) if !info[var_handle].is_empty() => {
468                                self.options.resolve_resource_binding(binding).is_err()
469                                    && self
470                                        .options
471                                        .resolve_external_texture_resource_binding(binding)
472                                        .is_err()
473                            }
474                            _ => false,
475                        })
476                {
477                    log::debug!(
478                        "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
479                        handle,
480                        function.name,
481                        var_handle
482                    );
483                    continue;
484                }
485            }
486
487            let ctx = back::FunctionCtx {
488                ty: back::FunctionType::Function(handle),
489                info,
490                expressions: &function.expressions,
491                named_expressions: &function.named_expressions,
492            };
493            let name = self.names[&NameKey::Function(handle)].clone();
494
495            self.write_wrapped_functions(module, &ctx)?;
496
497            self.write_function(module, name.as_str(), function, &ctx, info, String::new())?;
498
499            writeln!(self.out)?;
500        }
501
502        let mut translated_ep_names = Vec::with_capacity(ep_range.len());
503
504        // Write all entry points
505        for index in ep_range {
506            let ep = &module.entry_points[index];
507            let info = module_info.get_entry_point(index);
508
509            if !self.options.fake_missing_bindings {
510                let mut ep_error = None;
511                for (var_handle, var) in module.global_variables.iter() {
512                    match var.binding {
513                        Some(ref binding) if !info[var_handle].is_empty() => {
514                            if let Err(err) = self.options.resolve_resource_binding(binding) {
515                                if self
516                                    .options
517                                    .resolve_external_texture_resource_binding(binding)
518                                    .is_err()
519                                {
520                                    ep_error = Some(err);
521                                    break;
522                                }
523                            }
524                        }
525                        _ => {}
526                    }
527                }
528                if let Some(err) = ep_error {
529                    translated_ep_names.push(Err(err));
530                    continue;
531                }
532            }
533
534            let ctx = back::FunctionCtx {
535                ty: back::FunctionType::EntryPoint(index as u16),
536                info,
537                expressions: &ep.function.expressions,
538                named_expressions: &ep.function.named_expressions,
539            };
540
541            self.write_wrapped_functions(module, &ctx)?;
542
543            // Mesh/task shaders have a wrapper entry point which is declared after the "main"
544            // user-written function. We therefore cannot always just document the next function.
545            let mut attribute_string = String::new();
546            if ep.stage.compute_like() {
547                // HLSL is calling workgroup size "num threads"
548                let num_threads = ep.workgroup_size;
549                writeln!(
550                    attribute_string,
551                    "[numthreads({}, {}, {})]",
552                    num_threads[0], num_threads[1], num_threads[2]
553                )?;
554            }
555            if let Some(ref info) = ep.mesh_info {
556                let topology_str = match info.topology {
557                    crate::MeshOutputTopology::Points => unreachable!(),
558                    crate::MeshOutputTopology::Lines => "line",
559                    crate::MeshOutputTopology::Triangles => "triangle",
560                };
561                writeln!(attribute_string, "[outputtopology(\"{topology_str}\")]")?;
562            }
563
564            let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
565            self.write_function(module, &name, &ep.function, &ctx, info, attribute_string)?;
566
567            if index < module.entry_points.len() - 1 {
568                writeln!(self.out)?;
569            }
570
571            translated_ep_names.push(Ok(name));
572        }
573
574        Ok(super::ReflectionInfo {
575            entry_point_names: translated_ep_names,
576        })
577    }
578
579    fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
580        match *binding {
581            crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
582                write!(self.out, "precise ")?;
583            }
584            crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { perspective: false }) => {
585                write!(self.out, "noperspective ")?;
586            }
587            crate::Binding::Location {
588                interpolation,
589                sampling,
590                ..
591            } => {
592                if let Some(interpolation) = interpolation {
593                    if let Some(string) = interpolation.to_hlsl_str() {
594                        write!(self.out, "{string} ")?
595                    }
596                }
597
598                if let Some(sampling) = sampling {
599                    if let Some(string) = sampling.to_hlsl_str() {
600                        write!(self.out, "{string} ")?
601                    }
602                }
603            }
604            crate::Binding::BuiltIn(_) => {}
605        }
606
607        Ok(())
608    }
609
610    //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
611    // if they are struct, so that the `stage` argument here could be omitted.
612    pub(super) fn write_semantic(
613        &mut self,
614        binding: &Option<crate::Binding>,
615        stage: Option<(ShaderStage, Io)>,
616    ) -> BackendResult {
617        let is_per_primitive = match *binding {
618            Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
619                if builtin == crate::BuiltIn::ViewIndex
620                    && self.options.shader_model < ShaderModel::V6_1
621                {
622                    return Err(Error::ShaderModelTooLow(
623                        "used @builtin(view_index) or SV_ViewID".to_string(),
624                        ShaderModel::V6_1,
625                    ));
626                }
627                if let Some(builtin_str) = builtin.to_hlsl_str()? {
628                    write!(self.out, " : {builtin_str}")?;
629                }
630                false
631            }
632            Some(crate::Binding::Location {
633                blend_src: Some(1),
634                per_primitive,
635                ..
636            }) => {
637                write!(self.out, " : SV_Target1")?;
638                per_primitive
639            }
640            Some(crate::Binding::Location {
641                location,
642                per_primitive,
643                ..
644            }) => {
645                if stage == Some((ShaderStage::Fragment, Io::Output)) {
646                    write!(self.out, " : SV_Target{location}")?;
647                } else {
648                    write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
649                }
650                per_primitive
651            }
652            _ => false,
653        };
654        if is_per_primitive {
655            write!(self.out, " : primitive")?;
656        }
657
658        Ok(())
659    }
660
661    pub(super) fn write_interface_struct(
662        &mut self,
663        module: &Module,
664        shader_stage: (ShaderStage, Io),
665        struct_name: String,
666        var_name: Option<&str>,
667        mut members: Vec<EpStructMember>,
668    ) -> Result<EntryPointBinding, Error> {
669        let struct_name = self.namer.call(&struct_name);
670        // Sort the members so that first come the user-defined varyings
671        // in ascending locations, and then built-ins. This allows VS and FS
672        // interfaces to match with regards to order.
673        members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
674
675        write!(self.out, "struct {struct_name}")?;
676        writeln!(self.out, " {{")?;
677        let mut local_invocation_index_name = None;
678        let mut subgroup_id_used = false;
679        for m in members.iter() {
680            // Sanity check that each IO member is a built-in or is assigned a
681            // location. Also see note about nesting in `write_ep_input_struct`.
682            debug_assert!(m.binding.is_some());
683
684            match m.binding {
685                Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
686                    subgroup_id_used = true;
687                }
688                Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
689                    local_invocation_index_name = Some(m.name.clone());
690                }
691                _ => (),
692            }
693
694            if is_subgroup_builtin_binding(&m.binding) {
695                continue;
696            }
697            write!(self.out, "{}", back::INDENT)?;
698            if let Some(ref binding) = m.binding {
699                self.write_modifier(binding)?;
700            }
701            self.write_type(module, m.ty)?;
702            write!(self.out, " {}", &m.name)?;
703            self.write_semantic(&m.binding, Some(shader_stage))?;
704            writeln!(self.out, ";")?;
705        }
706        if subgroup_id_used && local_invocation_index_name.is_none() {
707            let name = self.namer.call("local_invocation_index");
708            writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
709            local_invocation_index_name = Some(name);
710        }
711        writeln!(self.out, "}};")?;
712        writeln!(self.out)?;
713
714        // See ordering notes on EntryPointInterface fields
715        match shader_stage.1 {
716            Io::Input => {
717                // bring back the original order
718                members.sort_by_key(|m| m.index);
719            }
720            Io::Output | Io::MeshVertices | Io::MeshPrimitives => {
721                // keep it sorted by binding
722            }
723        }
724
725        Ok(EntryPointBinding {
726            arg_name: self
727                .namer
728                .call(var_name.unwrap_or(struct_name.to_lowercase().as_str())),
729            ty_name: struct_name,
730            members,
731            local_invocation_index_name,
732        })
733    }
734
735    /// Flatten all entry point arguments into a single struct.
736    /// This is needed since we need to re-order them: first placing user locations,
737    /// then built-ins.
738    fn write_ep_input_struct(
739        &mut self,
740        module: &Module,
741        func: &crate::Function,
742        stage: ShaderStage,
743        entry_point_name: &str,
744    ) -> Result<EntryPointBinding, Error> {
745        let struct_name = format!("{stage:?}Input_{entry_point_name}");
746
747        let mut fake_members = Vec::new();
748        for arg in func.arguments.iter() {
749            // NOTE: We don't need to handle nesting structs. All members must
750            // be either built-ins or assigned a location. I.E. `binding` is
751            // `Some`. This is checked in `VaryingContext::validate`. See:
752            // https://gpuweb.github.io/gpuweb/wgsl/#input-output-locations
753            match module.types[arg.ty].inner {
754                TypeInner::Struct { ref members, .. } => {
755                    for member in members.iter() {
756                        let name = self.namer.call_or(&member.name, "member");
757                        let index = fake_members.len() as u32;
758                        fake_members.push(EpStructMember {
759                            name,
760                            ty: member.ty,
761                            binding: member.binding.clone(),
762                            index,
763                        });
764                    }
765                }
766                _ => {
767                    let member_name = self.namer.call_or(&arg.name, "member");
768                    let index = fake_members.len() as u32;
769                    fake_members.push(EpStructMember {
770                        name: member_name,
771                        ty: arg.ty,
772                        binding: arg.binding.clone(),
773                        index,
774                    });
775                }
776            }
777        }
778
779        self.write_interface_struct(module, (stage, Io::Input), struct_name, None, fake_members)
780    }
781
782    /// Flatten all entry point results into a single struct.
783    /// This is needed since we need to re-order them: first placing user locations,
784    /// then built-ins.
785    fn write_ep_output_struct(
786        &mut self,
787        module: &Module,
788        result: &crate::FunctionResult,
789        stage: ShaderStage,
790        entry_point_name: &str,
791        frag_ep: Option<&FragmentEntryPoint<'_>>,
792    ) -> Result<EntryPointBinding, Error> {
793        let struct_name = format!("{stage:?}Output_{entry_point_name}");
794
795        let empty = [];
796        let members = match module.types[result.ty].inner {
797            TypeInner::Struct { ref members, .. } => members,
798            ref other => {
799                log::error!("Unexpected {other:?} output type without a binding");
800                &empty[..]
801            }
802        };
803
804        // Gather list of fragment input locations. We use this below to remove user-defined
805        // varyings from VS outputs that aren't in the FS inputs. This makes the VS interface match
806        // as long as the FS inputs are a subset of the VS outputs. This is only applied if the
807        // writer is supplied with information about the fragment entry point.
808        let fs_input_locs = if let (Some(frag_ep), ShaderStage::Vertex) = (frag_ep, stage) {
809            let mut fs_input_locs = Vec::new();
810            for arg in frag_ep.func.arguments.iter() {
811                let mut push_if_location = |binding: &Option<crate::Binding>| match *binding {
812                    Some(crate::Binding::Location { location, .. }) => fs_input_locs.push(location),
813                    Some(crate::Binding::BuiltIn(_)) | None => {}
814                };
815
816                // NOTE: We don't need to handle struct nesting. See note in
817                // `write_ep_input_struct`.
818                match frag_ep.module.types[arg.ty].inner {
819                    TypeInner::Struct { ref members, .. } => {
820                        for member in members.iter() {
821                            push_if_location(&member.binding);
822                        }
823                    }
824                    _ => push_if_location(&arg.binding),
825                }
826            }
827            fs_input_locs.sort();
828            Some(fs_input_locs)
829        } else {
830            None
831        };
832
833        let mut fake_members = Vec::new();
834        for (index, member) in members.iter().enumerate() {
835            if let Some(ref fs_input_locs) = fs_input_locs {
836                match member.binding {
837                    Some(crate::Binding::Location { location, .. }) => {
838                        if fs_input_locs.binary_search(&location).is_err() {
839                            continue;
840                        }
841                    }
842                    Some(crate::Binding::BuiltIn(_)) | None => {}
843                }
844            }
845
846            let member_name = self.namer.call_or(&member.name, "member");
847            fake_members.push(EpStructMember {
848                name: member_name,
849                ty: member.ty,
850                binding: member.binding.clone(),
851                index: index as u32,
852            });
853        }
854
855        self.write_interface_struct(module, (stage, Io::Output), struct_name, None, fake_members)
856    }
857
858    /// Writes special interface structures for an entry point. The special structures have
859    /// all the fields flattened into them and sorted by binding. They are needed to emulate
860    /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
861    fn write_ep_interface(
862        &mut self,
863        module: &Module,
864        ep: &crate::EntryPoint,
865        ep_name: &str,
866        frag_ep: Option<&FragmentEntryPoint<'_>>,
867    ) -> Result<EntryPointInterface, Error> {
868        let func = &ep.function;
869        let stage = ep.stage;
870        Ok(EntryPointInterface {
871            input: if !func.arguments.is_empty()
872                && (stage == ShaderStage::Fragment
873                    || func
874                        .arguments
875                        .iter()
876                        .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
877            {
878                Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
879            } else {
880                None
881            },
882            output: match func.result {
883                Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
884                    Some(self.write_ep_output_struct(module, fr, stage, ep_name, frag_ep)?)
885                }
886                _ => None,
887            },
888            mesh_vertices: if let Some(ref info) = ep.mesh_info {
889                Some(self.write_ep_mesh_output_struct(module, ep_name, false, info)?)
890            } else {
891                None
892            },
893            mesh_primitives: if let Some(ref info) = ep.mesh_info {
894                Some(self.write_ep_mesh_output_struct(module, ep_name, true, info)?)
895            } else {
896                None
897            },
898            mesh_indices: if let Some(ref info) = ep.mesh_info {
899                Some(self.write_ep_mesh_output_indices(info.topology)?)
900            } else {
901                None
902            },
903        })
904    }
905
906    fn write_ep_argument_initialization(
907        &mut self,
908        ep: &crate::EntryPoint,
909        ep_input: &EntryPointBinding,
910        fake_member: &EpStructMember,
911    ) -> BackendResult {
912        match fake_member.binding {
913            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
914                write!(self.out, "WaveGetLaneCount()")?
915            }
916            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
917                write!(self.out, "WaveGetLaneIndex()")?
918            }
919            Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
920                self.out,
921                "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
922                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
923            )?,
924            Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
925                write!(
926                    self.out,
927                    "{}.{} / WaveGetLaneCount()",
928                    ep_input.arg_name,
929                    // When writing SubgroupId, we always guarantee that local_invocation_index_name is written
930                    ep_input.local_invocation_index_name.as_ref().unwrap()
931                )?;
932            }
933            Some(crate::Binding::Location {
934                interpolation: Some(crate::Interpolation::PerVertex),
935                ..
936            }) => {
937                if self.options.shader_model < ShaderModel::V6_1 {
938                    return Err(Error::ShaderModelTooLow(
939                        "per_vertex fragment inputs".to_string(),
940                        ShaderModel::V6_1,
941                    ));
942                }
943                write!(
944                    self.out,
945                    "{{ GetAttributeAtVertex({0}.{1}, 0), GetAttributeAtVertex({0}.{1}, 1), GetAttributeAtVertex({0}.{1}, 2) }}",
946                    ep_input.arg_name,
947                    fake_member.name,
948                )?;
949            }
950            _ => {
951                write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
952            }
953        }
954        Ok(())
955    }
956
957    /// Write an entry point preface that initializes the arguments as specified in IR.
958    fn write_ep_arguments_initialization(
959        &mut self,
960        module: &Module,
961        func: &crate::Function,
962        ep_index: u16,
963    ) -> BackendResult {
964        let ep = &module.entry_points[ep_index as usize];
965        let ep_input = match self
966            .entry_point_io
967            .get_mut(&(ep_index as usize))
968            .unwrap()
969            .input
970            .take()
971        {
972            Some(ep_input) => ep_input,
973            None => return Ok(()),
974        };
975        let mut fake_iter = ep_input.members.iter();
976        for (arg_index, arg) in func.arguments.iter().enumerate() {
977            write!(self.out, "{}", back::INDENT)?;
978            self.write_type(module, arg.ty)?;
979            let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
980            write!(self.out, " {arg_name}")?;
981            match module.types[arg.ty].inner {
982                TypeInner::Array { base, size, .. } => {
983                    self.write_array_size(module, base, size)?;
984                    write!(self.out, " = ")?;
985                    self.write_ep_argument_initialization(
986                        ep,
987                        &ep_input,
988                        fake_iter.next().unwrap(),
989                    )?;
990                    writeln!(self.out, ";")?;
991                }
992                TypeInner::Struct { ref members, .. } => {
993                    write!(self.out, " = {{ ")?;
994                    for index in 0..members.len() {
995                        if index != 0 {
996                            write!(self.out, ", ")?;
997                        }
998                        self.write_ep_argument_initialization(
999                            ep,
1000                            &ep_input,
1001                            fake_iter.next().unwrap(),
1002                        )?;
1003                    }
1004                    writeln!(self.out, " }};")?;
1005                }
1006                _ => {
1007                    write!(self.out, " = ")?;
1008                    self.write_ep_argument_initialization(
1009                        ep,
1010                        &ep_input,
1011                        fake_iter.next().unwrap(),
1012                    )?;
1013                    writeln!(self.out, ";")?;
1014                }
1015            }
1016        }
1017        assert!(fake_iter.next().is_none());
1018        Ok(())
1019    }
1020
1021    /// Helper method used to write global variables
1022    /// # Notes
1023    /// Always adds a newline
1024    fn write_global(
1025        &mut self,
1026        module: &Module,
1027        handle: Handle<crate::GlobalVariable>,
1028    ) -> BackendResult {
1029        let global = &module.global_variables[handle];
1030        let inner = &module.types[global.ty].inner;
1031
1032        let handle_ty = match *inner {
1033            TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
1034            _ => inner,
1035        };
1036
1037        // External textures are handled entirely differently, so defer entirely to that method.
1038        // We do so prior to calling resolve_resource_binding() below, as we even need to resolve
1039        // their bindings separately.
1040        let is_external_texture = matches!(
1041            *handle_ty,
1042            TypeInner::Image {
1043                class: crate::ImageClass::External,
1044                ..
1045            }
1046        );
1047        if is_external_texture {
1048            return self.write_global_external_texture(module, handle, global);
1049        }
1050
1051        if let Some(ref binding) = global.binding {
1052            if let Err(err) = self.options.resolve_resource_binding(binding) {
1053                log::debug!(
1054                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1055                    handle,
1056                    global.name,
1057                    err,
1058                );
1059                return Ok(());
1060            }
1061        }
1062
1063        // Samplers are handled entirely differently, so defer entirely to that method.
1064        let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });
1065
1066        if is_sampler {
1067            return self.write_global_sampler(module, handle, global);
1068        }
1069
1070        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
1071        let register_ty = match global.space {
1072            crate::AddressSpace::Function => unreachable!("Function address space"),
1073            crate::AddressSpace::Private => {
1074                write!(self.out, "static ")?;
1075                self.write_type(module, global.ty)?;
1076                ""
1077            }
1078            crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => {
1079                write!(self.out, "groupshared ")?;
1080                self.write_type(module, global.ty)?;
1081                ""
1082            }
1083            crate::AddressSpace::Uniform => {
1084                // constant buffer declarations are expected to be inlined, e.g.
1085                // `cbuffer foo: register(b0) { field1: type1; }`
1086                write!(self.out, "cbuffer")?;
1087                "b"
1088            }
1089            crate::AddressSpace::Storage { access } => {
1090                if global
1091                    .memory_decorations
1092                    .contains(crate::MemoryDecorations::COHERENT)
1093                {
1094                    write!(self.out, "globallycoherent ")?;
1095                }
1096                let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
1097                    ("RW", "u")
1098                } else {
1099                    ("", "t")
1100                };
1101                write!(self.out, "{prefix}ByteAddressBuffer")?;
1102                register
1103            }
1104            crate::AddressSpace::Handle => {
1105                let register = match *handle_ty {
1106                    // all storage textures are UAV, unconditionally
1107                    TypeInner::Image {
1108                        class: crate::ImageClass::Storage { .. },
1109                        ..
1110                    } => "u",
1111                    _ => "t",
1112                };
1113                self.write_type(module, global.ty)?;
1114                register
1115            }
1116            crate::AddressSpace::Immediate => {
1117                // The type of the immediates will be wrapped in `ConstantBuffer`
1118                write!(self.out, "ConstantBuffer<")?;
1119                "b"
1120            }
1121            crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
1122                unimplemented!()
1123            }
1124        };
1125
1126        // If the global is a immediate data write the type now because it will be a
1127        // generic argument to `ConstantBuffer`
1128        if global.space == crate::AddressSpace::Immediate {
1129            self.write_global_type(module, global.ty)?;
1130
1131            // need to write the array size if the type was emitted with `write_type`
1132            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1133                self.write_array_size(module, base, size)?;
1134            }
1135
1136            // Close the angled brackets for the generic argument
1137            write!(self.out, ">")?;
1138        }
1139
1140        let name = &self.names[&NameKey::GlobalVariable(handle)];
1141        write!(self.out, " {name}")?;
1142
1143        // Immediates need to be assigned a binding explicitly by the consumer
1144        // since naga has no way to know the binding from the shader alone
1145        if global.space == crate::AddressSpace::Immediate {
1146            match module.types[global.ty].inner {
1147                TypeInner::Struct { .. } => {}
1148                _ => {
1149                    return Err(Error::Unimplemented(format!(
1150                        "push-constant '{name}' has non-struct type; tracked by: https://github.com/gfx-rs/wgpu/issues/5683"
1151                    )));
1152                }
1153            }
1154
1155            let target = self
1156                .options
1157                .immediates_target
1158                .as_ref()
1159                .expect("No bind target was defined for the immediates block");
1160            write!(self.out, ": register(b{}", target.register)?;
1161            if target.space != 0 {
1162                write!(self.out, ", space{}", target.space)?;
1163            }
1164            write!(self.out, ")")?;
1165        }
1166
1167        if let Some(ref binding) = global.binding {
1168            // this was already resolved earlier when we started evaluating an entry point.
1169            let bt = self.options.resolve_resource_binding(binding).unwrap();
1170
1171            // need to write the binding array size if the type was emitted with `write_type`
1172            if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
1173                if let Some(overridden_size) = bt.binding_array_size {
1174                    write!(self.out, "[{overridden_size}]")?;
1175                } else {
1176                    self.write_array_size(module, base, size)?;
1177                }
1178            }
1179
1180            write!(self.out, " : register({}{}", register_ty, bt.register)?;
1181            if bt.space != 0 {
1182                write!(self.out, ", space{}", bt.space)?;
1183            }
1184            write!(self.out, ")")?;
1185        } else {
1186            // need to write the array size if the type was emitted with `write_type`
1187            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1188                self.write_array_size(module, base, size)?;
1189            }
1190            if global.space == crate::AddressSpace::Private {
1191                write!(self.out, " = ")?;
1192                if let Some(init) = global.init {
1193                    self.write_const_expression(module, init, &module.global_expressions)?;
1194                } else {
1195                    self.write_default_init(module, global.ty)?;
1196                }
1197            }
1198        }
1199
1200        if global.space == crate::AddressSpace::Uniform {
1201            write!(self.out, " {{ ")?;
1202
1203            self.write_global_type(module, global.ty)?;
1204
1205            write!(
1206                self.out,
1207                " {}",
1208                &self.names[&NameKey::GlobalVariable(handle)]
1209            )?;
1210
1211            // need to write the array size if the type was emitted with `write_type`
1212            if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
1213                self.write_array_size(module, base, size)?;
1214            }
1215
1216            writeln!(self.out, "; }}")?;
1217        } else {
1218            writeln!(self.out, ";")?;
1219        }
1220
1221        Ok(())
1222    }
1223
1224    fn write_global_sampler(
1225        &mut self,
1226        module: &Module,
1227        handle: Handle<crate::GlobalVariable>,
1228        global: &crate::GlobalVariable,
1229    ) -> BackendResult {
1230        let binding = *global.binding.as_ref().unwrap();
1231
1232        let key = super::SamplerIndexBufferKey {
1233            group: binding.group,
1234        };
1235        self.write_wrapped_sampler_buffer(key)?;
1236
1237        // This was already validated, so we can confidently unwrap it.
1238        let bt = self.options.resolve_resource_binding(&binding).unwrap();
1239
1240        match module.types[global.ty].inner {
1241            TypeInner::Sampler { comparison } => {
1242                // If we are generating a static access, we create a variable for the sampler.
1243                //
1244                // This prevents the DXIL from containing multiple lookups for the sampler, which
1245                // the backend compiler will then have to eliminate. AMD does seem to be able to
1246                // eliminate these, but better safe than sorry.
1247
1248                write!(self.out, "static const ")?;
1249                self.write_type(module, global.ty)?;
1250
1251                let heap_var = if comparison {
1252                    COMPARISON_SAMPLER_HEAP_VAR
1253                } else {
1254                    SAMPLER_HEAP_VAR
1255                };
1256
1257                let index_buffer_name = &self.wrapped.sampler_index_buffers[&key];
1258                let name = &self.names[&NameKey::GlobalVariable(handle)];
1259                writeln!(
1260                    self.out,
1261                    " {name} = {heap_var}[{index_buffer_name}[{register}]];",
1262                    register = bt.register
1263                )?;
1264            }
1265            TypeInner::BindingArray { .. } => {
1266                // If we are generating a binding array, we cannot directly access the sampler as the index
1267                // into the sampler index buffer is unknown at compile time. Instead we generate a constant
1268                // that represents the "base" index into the sampler index buffer. This constant is added
1269                // to the user provided index to get the final index into the sampler index buffer.
1270
1271                let name = &self.names[&NameKey::GlobalVariable(handle)];
1272                writeln!(
1273                    self.out,
1274                    "static const uint {name} = {register};",
1275                    register = bt.register
1276                )?;
1277            }
1278            _ => unreachable!(),
1279        };
1280
1281        Ok(())
1282    }
1283
1284    /// Write the declarations for an external texture global variable.
1285    /// These are emitted as multiple global variables: Three `Texture2D`s
1286    /// (one for each plane) and a parameters cbuffer.
1287    fn write_global_external_texture(
1288        &mut self,
1289        module: &Module,
1290        handle: Handle<crate::GlobalVariable>,
1291        global: &crate::GlobalVariable,
1292    ) -> BackendResult {
1293        let res_binding = global
1294            .binding
1295            .as_ref()
1296            .expect("External texture global variables must have a resource binding");
1297        let ext_tex_bindings = match self
1298            .options
1299            .resolve_external_texture_resource_binding(res_binding)
1300        {
1301            Ok(bindings) => bindings,
1302            Err(err) => {
1303                log::debug!(
1304                    "Skipping global {:?} (name {:?}) for being inaccessible: {}",
1305                    handle,
1306                    global.name,
1307                    err,
1308                );
1309                return Ok(());
1310            }
1311        };
1312
1313        let mut write_plane = |bt: &super::BindTarget, name| -> BackendResult {
1314            write!(
1315                self.out,
1316                "Texture2D<float4> {}: register(t{}",
1317                name, bt.register
1318            )?;
1319            if bt.space != 0 {
1320                write!(self.out, ", space{}", bt.space)?;
1321            }
1322            writeln!(self.out, ");")?;
1323            Ok(())
1324        };
1325        for (i, bt) in ext_tex_bindings.planes.iter().enumerate() {
1326            let plane_name = &self.names
1327                [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Plane(i))];
1328            write_plane(bt, plane_name)?;
1329        }
1330
1331        let params_name = &self.names
1332            [&NameKey::ExternalTextureGlobalVariable(handle, ExternalTextureNameKey::Params)];
1333        let params_ty_name =
1334            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1335        write!(
1336            self.out,
1337            "cbuffer {}: register(b{}",
1338            params_name, ext_tex_bindings.params.register
1339        )?;
1340        if ext_tex_bindings.params.space != 0 {
1341            write!(self.out, ", space{}", ext_tex_bindings.params.space)?;
1342        }
1343        writeln!(self.out, ") {{ {params_ty_name} {params_name}; }};")?;
1344
1345        Ok(())
1346    }
1347
1348    /// Helper method used to write global constants
1349    ///
1350    /// # Notes
1351    /// Ends in a newline
1352    fn write_global_constant(
1353        &mut self,
1354        module: &Module,
1355        handle: Handle<crate::Constant>,
1356    ) -> BackendResult {
1357        write!(self.out, "static const ")?;
1358        let constant = &module.constants[handle];
1359        self.write_type(module, constant.ty)?;
1360        let name = &self.names[&NameKey::Constant(handle)];
1361        write!(self.out, " {name}")?;
1362        // Write size for array type
1363        if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
1364            self.write_array_size(module, base, size)?;
1365        }
1366        write!(self.out, " = ")?;
1367        self.write_const_expression(module, constant.init, &module.global_expressions)?;
1368        writeln!(self.out, ";")?;
1369        Ok(())
1370    }
1371
1372    pub(super) fn write_array_size(
1373        &mut self,
1374        module: &Module,
1375        base: Handle<crate::Type>,
1376        size: crate::ArraySize,
1377    ) -> BackendResult {
1378        write!(self.out, "[")?;
1379
1380        match size.resolve(module.to_ctx())? {
1381            proc::IndexableLength::Known(size) => {
1382                write!(self.out, "{size}")?;
1383            }
1384            proc::IndexableLength::Dynamic => unreachable!(),
1385        }
1386
1387        write!(self.out, "]")?;
1388
1389        if let TypeInner::Array {
1390            base: next_base,
1391            size: next_size,
1392            ..
1393        } = module.types[base].inner
1394        {
1395            self.write_array_size(module, next_base, next_size)?;
1396        }
1397
1398        Ok(())
1399    }
1400
1401    /// Helper method used to write structs
1402    ///
1403    /// # Notes
1404    /// Ends in a newline
1405    fn write_struct(
1406        &mut self,
1407        module: &Module,
1408        handle: Handle<crate::Type>,
1409        members: &[crate::StructMember],
1410        span: u32,
1411        shader_stage: Option<(ShaderStage, Io)>,
1412    ) -> BackendResult {
1413        // Write struct name
1414        let struct_name = &self.names[&NameKey::Type(handle)];
1415        writeln!(self.out, "struct {struct_name} {{")?;
1416
1417        let mut last_offset = 0;
1418        for (index, member) in members.iter().enumerate() {
1419            if member.binding.is_none() && member.offset > last_offset {
1420                // using int as padding should work as long as the backend
1421                // doesn't support a type that's less than 4 bytes in size
1422                // (Error::UnsupportedScalar catches this)
1423                let padding = (member.offset - last_offset) / 4;
1424                for i in 0..padding {
1425                    writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
1426                }
1427            }
1428            let ty_inner = &module.types[member.ty].inner;
1429            last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;
1430
1431            // The indentation is only for readability
1432            write!(self.out, "{}", back::INDENT)?;
1433
1434            match module.types[member.ty].inner {
1435                TypeInner::Array { base, size, .. } => {
1436                    // HLSL arrays are written as `type name[size]`
1437
1438                    self.write_global_type(module, member.ty)?;
1439
1440                    // Write `name`
1441                    write!(
1442                        self.out,
1443                        " {}",
1444                        &self.names[&NameKey::StructMember(handle, index as u32)]
1445                    )?;
1446                    // Write [size]
1447                    self.write_array_size(module, base, size)?;
1448                }
1449                // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1450                // See the module-level block comment in mod.rs for details.
1451                TypeInner::Matrix {
1452                    rows,
1453                    columns,
1454                    scalar,
1455                } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
1456                    let vec_ty = TypeInner::Vector { size: rows, scalar };
1457                    let field_name_key = NameKey::StructMember(handle, index as u32);
1458
1459                    for i in 0..columns as u8 {
1460                        if i != 0 {
1461                            write!(self.out, "; ")?;
1462                        }
1463                        self.write_value_type(module, &vec_ty)?;
1464                        write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
1465                    }
1466                }
1467                _ => {
1468                    // Write modifier before type
1469                    if let Some(ref binding) = member.binding {
1470                        self.write_modifier(binding)?;
1471                    }
1472
1473                    // Even though Naga IR matrices are column-major, we must describe
1474                    // matrices passed from the CPU as being in row-major order.
1475                    // See the module-level block comment in mod.rs for details.
1476                    if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
1477                        write!(self.out, "row_major ")?;
1478                    }
1479
1480                    // Write the member type and name
1481                    self.write_type(module, member.ty)?;
1482                    write!(
1483                        self.out,
1484                        " {}",
1485                        &self.names[&NameKey::StructMember(handle, index as u32)]
1486                    )?;
1487                }
1488            }
1489
1490            self.write_semantic(&member.binding, shader_stage)?;
1491            writeln!(self.out, ";")?;
1492        }
1493
1494        // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
1495        if members.last().unwrap().binding.is_none() && span > last_offset {
1496            let padding = (span - last_offset) / 4;
1497            for i in 0..padding {
1498                writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
1499            }
1500        }
1501
1502        writeln!(self.out, "}};")?;
1503        Ok(())
1504    }
1505
1506    /// Helper method used to write global/structs non image/sampler types
1507    ///
1508    /// # Notes
1509    /// Adds no trailing or leading whitespace
1510    pub(super) fn write_global_type(
1511        &mut self,
1512        module: &Module,
1513        ty: Handle<crate::Type>,
1514    ) -> BackendResult {
1515        let matrix_data = get_inner_matrix_data(module, ty);
1516
1517        // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
1518        // See the module-level block comment in mod.rs for details.
1519        if let Some(MatrixType {
1520            columns,
1521            rows: crate::VectorSize::Bi,
1522            width,
1523        }) = matrix_data
1524        {
1525            write!(self.out, "__mat{}x2_f{}", columns as u8, width * 8)?;
1526        } else {
1527            // Even though Naga IR matrices are column-major, we must describe
1528            // matrices passed from the CPU as being in row-major order.
1529            // See the module-level block comment in mod.rs for details.
1530            if matrix_data.is_some() {
1531                write!(self.out, "row_major ")?;
1532            }
1533
1534            self.write_type(module, ty)?;
1535        }
1536
1537        Ok(())
1538    }
1539
1540    /// Helper method used to write non image/sampler types
1541    ///
1542    /// # Notes
1543    /// Adds no trailing or leading whitespace
1544    pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1545        let inner = &module.types[ty].inner;
1546        match *inner {
1547            TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1548            // hlsl array has the size separated from the base type
1549            TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1550                self.write_type(module, base)?
1551            }
1552            ref other => self.write_value_type(module, other)?,
1553        }
1554
1555        Ok(())
1556    }
1557
1558    /// Helper method used to write value types
1559    ///
1560    /// # Notes
1561    /// Adds no trailing or leading whitespace
1562    pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1563        match *inner {
1564            TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1565                write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1566            }
1567            TypeInner::Vector { size, scalar } => {
1568                write!(
1569                    self.out,
1570                    "{}{}",
1571                    scalar.to_hlsl_str()?,
1572                    common::vector_size_str(size)
1573                )?;
1574            }
1575            TypeInner::Matrix {
1576                columns,
1577                rows,
1578                scalar,
1579            } => {
1580                // The IR supports only float matrix
1581                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
1582
1583                // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
1584                write!(
1585                    self.out,
1586                    "{}{}x{}",
1587                    scalar.to_hlsl_str()?,
1588                    common::vector_size_str(columns),
1589                    common::vector_size_str(rows),
1590                )?;
1591            }
1592            TypeInner::Image {
1593                dim,
1594                arrayed,
1595                class,
1596            } => {
1597                self.write_image_type(dim, arrayed, class)?;
1598            }
1599            TypeInner::Sampler { comparison } => {
1600                let sampler = if comparison {
1601                    "SamplerComparisonState"
1602                } else {
1603                    "SamplerState"
1604                };
1605                write!(self.out, "{sampler}")?;
1606            }
1607            // HLSL arrays are written as `type name[size]`
1608            // Current code is written arrays only as `[size]`
1609            // Base `type` and `name` should be written outside
1610            TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1611                self.write_array_size(module, base, size)?;
1612            }
1613            TypeInner::AccelerationStructure { .. } => {
1614                write!(self.out, "RaytracingAccelerationStructure")?;
1615            }
1616            TypeInner::RayQuery { .. } => {
1617                // these are constant flags, there are dynamic flags also but constant flags are not supported by naga
1618                write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
1619            }
1620            _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1621        }
1622
1623        Ok(())
1624    }
1625
1626    /// Helper method used to write functions
1627    /// # Notes
1628    /// Ends in a newline
1629    fn write_function(
1630        &mut self,
1631        module: &Module,
1632        name: &str,
1633        func: &crate::Function,
1634        func_ctx: &back::FunctionCtx<'_>,
1635        info: &valid::FunctionInfo,
1636        header: String,
1637    ) -> BackendResult {
1638        // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
1639
1640        self.update_expressions_to_bake(module, func, info);
1641        let ep = match func_ctx.ty {
1642            back::FunctionType::EntryPoint(idx) => Some(&module.entry_points[idx as usize]),
1643            back::FunctionType::Function(_) => None,
1644        };
1645
1646        let nested = matches!(
1647            ep,
1648            Some(crate::EntryPoint {
1649                stage: ShaderStage::Task | ShaderStage::Mesh,
1650                ..
1651            })
1652        );
1653        if !nested {
1654            write!(self.out, "{header}")?;
1655        }
1656
1657        if let Some(ref result) = func.result {
1658            // Write typedef if return type is an array
1659            let array_return_type = match module.types[result.ty].inner {
1660                TypeInner::Array { base, size, .. } => {
1661                    let array_return_type = self.namer.call(&format!("ret_{name}"));
1662                    write!(self.out, "typedef ")?;
1663                    self.write_type(module, result.ty)?;
1664                    write!(self.out, " {array_return_type}")?;
1665                    self.write_array_size(module, base, size)?;
1666                    writeln!(self.out, ";")?;
1667                    Some(array_return_type)
1668                }
1669                _ => None,
1670            };
1671
1672            // Write modifier
1673            if let Some(
1674                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }),
1675            ) = result.binding
1676            {
1677                self.write_modifier(binding)?;
1678            }
1679
1680            // Write return type
1681            match func_ctx.ty {
1682                back::FunctionType::Function(_) => {
1683                    if let Some(array_return_type) = array_return_type {
1684                        write!(self.out, "{array_return_type}")?;
1685                    } else {
1686                        self.write_type(module, result.ty)?;
1687                    }
1688                }
1689                back::FunctionType::EntryPoint(index) => {
1690                    if let Some(ref ep_output) =
1691                        self.entry_point_io.get(&(index as usize)).unwrap().output
1692                    {
1693                        write!(self.out, "{}", ep_output.ty_name)?;
1694                    } else {
1695                        self.write_type(module, result.ty)?;
1696                    }
1697                }
1698            }
1699        } else {
1700            write!(self.out, "void")?;
1701        }
1702
1703        let nested_name = if nested {
1704            self.namer.call(&format!("_{name}"))
1705        } else {
1706            name.to_string()
1707        };
1708
1709        // Write function name
1710        write!(self.out, " {nested_name}(")?;
1711
1712        let need_workgroup_variables_initialization =
1713            self.need_workgroup_variables_initialization(func_ctx, module);
1714
1715        let mut any_args_written = false;
1716        let mut separator = || {
1717            if any_args_written {
1718                ", "
1719            } else {
1720                any_args_written = true;
1721                ""
1722            }
1723        };
1724
1725        let needs_local_invocation_index_name = need_workgroup_variables_initialization || nested;
1726        let mut local_invocation_index_name = None;
1727        // For nested entry points, collect arg names as we write them so that
1728        // write_nested_function_outer can pass the exact same names to the call site.
1729        let mut nested_wgsl_args: Vec<String> = Vec::new();
1730        let mut nested_task_payload_name: Option<String> = None;
1731        // Write function arguments for non entry point functions
1732        match func_ctx.ty {
1733            back::FunctionType::Function(handle) => {
1734                for (index, arg) in func.arguments.iter().enumerate() {
1735                    write!(self.out, "{}", separator())?;
1736                    self.write_function_argument(module, handle, arg, index)?;
1737                }
1738                // If this reads a task payload variable the variable needs to be passed as an `in` argument
1739                for (var_handle, var) in module.global_variables.iter() {
1740                    let uses = info[var_handle];
1741                    if uses.contains(valid::GlobalUse::READ)
1742                        && !uses.contains(valid::GlobalUse::WRITE)
1743                        && var.space == crate::AddressSpace::TaskPayload
1744                    {
1745                        self.function_task_payload_var.insert(handle, var_handle);
1746                        write!(self.out, "{}in ", separator())?;
1747
1748                        self.write_type(module, var.ty)?;
1749                        let name = &self.names[&NameKey::GlobalVariable(var_handle)];
1750                        write!(self.out, " {name}")?;
1751                        break;
1752                    }
1753                }
1754            }
1755            back::FunctionType::EntryPoint(ep_index) => {
1756                let ep = &module.entry_points[ep_index as usize];
1757                if let Some(ref ep_input) =
1758                    self.entry_point_io.get(&(ep_index as usize)).unwrap().input
1759                {
1760                    write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
1761                    separator();
1762                    nested_wgsl_args.push(ep_input.arg_name.clone());
1763                } else {
1764                    let stage = ep.stage;
1765                    for (index, arg) in func.arguments.iter().enumerate() {
1766                        write!(self.out, "{}", separator())?;
1767                        self.write_type(module, arg.ty)?;
1768
1769                        let argument_name =
1770                            &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
1771
1772                        if arg.binding
1773                            == Some(crate::Binding::BuiltIn(
1774                                crate::BuiltIn::LocalInvocationIndex,
1775                            ))
1776                        {
1777                            local_invocation_index_name = Some(argument_name.clone());
1778                        }
1779
1780                        nested_wgsl_args.push(argument_name.clone());
1781                        write!(self.out, " {argument_name}")?;
1782                        if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
1783                            self.write_array_size(module, base, size)?;
1784                        }
1785
1786                        self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
1787                    }
1788                }
1789                if ep.stage == ShaderStage::Mesh {
1790                    if let Some(var_handle) = ep.task_payload {
1791                        let var = &module.global_variables[var_handle];
1792                        write!(self.out, "{}in ", separator())?;
1793                        self.write_type(module, var.ty)?;
1794                        let arg_name = &self.names[&NameKey::GlobalVariable(var_handle)];
1795                        write!(self.out, " {arg_name}")?;
1796                        nested_task_payload_name = Some(arg_name.clone());
1797                        if let TypeInner::Array { base, size, .. } = module.types[var.ty].inner {
1798                            self.write_array_size(module, base, size)?;
1799                        }
1800                    }
1801                }
1802                if needs_local_invocation_index_name && local_invocation_index_name.is_none() {
1803                    let name = self.namer.call("local_invocation_index");
1804                    write!(self.out, "{}uint {name}", separator())?;
1805                    write!(self.out, " : SV_GroupIndex")?;
1806                    local_invocation_index_name = Some(name);
1807                }
1808            }
1809        }
1810        // Ends of arguments
1811        write!(self.out, ")")?;
1812
1813        // Write semantic if it present
1814        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1815            let stage = module.entry_points[index as usize].stage;
1816            if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
1817                self.write_semantic(binding, Some((stage, Io::Output)))?;
1818            }
1819        }
1820
1821        // Function body start
1822        writeln!(self.out)?;
1823        writeln!(self.out, "{{")?;
1824
1825        if need_workgroup_variables_initialization && !nested {
1826            let back::FunctionType::EntryPoint(index) = func_ctx.ty else {
1827                unreachable!();
1828            };
1829            writeln!(
1830                self.out,
1831                "{}if ({} == 0) {{",
1832                back::INDENT,
1833                // need_workgroup_variables_initialization forces this to be written
1834                // if the user doesn't specify it (so this must be Some())
1835                local_invocation_index_name.as_ref().unwrap(),
1836            )?;
1837            self.write_workgroup_variables_initialization(
1838                func_ctx,
1839                module,
1840                module.entry_points[index as usize].stage,
1841            )?;
1842
1843            writeln!(self.out, "{}}}", back::INDENT)?;
1844            self.write_control_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
1845        }
1846
1847        if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
1848            self.write_ep_arguments_initialization(module, func, index)?;
1849        }
1850
1851        // Write function local variables
1852        for (handle, local) in func.local_variables.iter() {
1853            // Write indentation (only for readability)
1854            write!(self.out, "{}", back::INDENT)?;
1855
1856            // Write the local name
1857            // The leading space is important
1858            self.write_type(module, local.ty)?;
1859            write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
1860            // Write size for array type
1861            if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
1862                self.write_array_size(module, base, size)?;
1863            }
1864
1865            let is_ray_query = match module.types[local.ty].inner {
1866                // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
1867                TypeInner::RayQuery { .. } => true,
1868                _ => {
1869                    write!(self.out, " = ")?;
1870                    // Write the local initializer if needed
1871                    if let Some(init) = local.init {
1872                        self.write_expr(module, init, func_ctx)?;
1873                    } else {
1874                        // Zero initialize local variables
1875                        self.write_default_init(module, local.ty)?;
1876                    }
1877                    false
1878                }
1879            };
1880            // Finish the local with `;` and add a newline (only for readability)
1881            writeln!(self.out, ";")?;
1882            // If it's a ray query, we also want a tracker variable
1883            if is_ray_query {
1884                write!(self.out, "{}", back::INDENT)?;
1885                self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?;
1886                writeln!(
1887                    self.out,
1888                    " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;",
1889                    self.names[&func_ctx.name_key(handle)]
1890                )?;
1891            }
1892        }
1893
1894        if !func.local_variables.is_empty() {
1895            writeln!(self.out)?;
1896        }
1897
1898        // Write the function body (statement list)
1899        for sta in func.body.iter() {
1900            // The indentation should always be 1 when writing the function body
1901            self.write_stmt(module, sta, func_ctx, back::Level(1))?;
1902        }
1903
1904        writeln!(self.out, "}}")?;
1905
1906        if nested {
1907            self.write_nested_function_outer(
1908                module,
1909                func_ctx,
1910                &header,
1911                name,
1912                need_workgroup_variables_initialization,
1913                &nested_name,
1914                ep.unwrap(),
1915                NestedEntryPointArgs {
1916                    user_args: nested_wgsl_args,
1917                    task_payload: nested_task_payload_name,
1918                    // guaranteed to be set for nested functions (task/mesh shaders)
1919                    local_invocation_index: local_invocation_index_name.unwrap(),
1920                },
1921            )?;
1922        }
1923
1924        self.named_expressions.clear();
1925
1926        Ok(())
1927    }
1928
1929    fn write_function_argument(
1930        &mut self,
1931        module: &Module,
1932        handle: Handle<crate::Function>,
1933        arg: &crate::FunctionArgument,
1934        index: usize,
1935    ) -> BackendResult {
1936        // External texture arguments must be expanded into separate
1937        // arguments for each plane and the params buffer.
1938        if let TypeInner::Image {
1939            class: crate::ImageClass::External,
1940            ..
1941        } = module.types[arg.ty].inner
1942        {
1943            return self.write_function_external_texture_argument(module, handle, index);
1944        }
1945
1946        // Write argument type
1947        let arg_ty = match module.types[arg.ty].inner {
1948            // pointers in function arguments are expected and resolve to `inout`
1949            TypeInner::Pointer { base, .. } => {
1950                //TODO: can we narrow this down to just `in` when possible?
1951                write!(self.out, "inout ")?;
1952                base
1953            }
1954            _ => arg.ty,
1955        };
1956        self.write_type(module, arg_ty)?;
1957
1958        let argument_name = &self.names[&NameKey::FunctionArgument(handle, index as u32)];
1959
1960        // Write argument name. Space is important.
1961        write!(self.out, " {argument_name}")?;
1962        if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
1963            self.write_array_size(module, base, size)?;
1964        }
1965
1966        Ok(())
1967    }
1968
1969    fn write_function_external_texture_argument(
1970        &mut self,
1971        module: &Module,
1972        handle: Handle<crate::Function>,
1973        index: usize,
1974    ) -> BackendResult {
1975        let plane_names = [0, 1, 2].map(|i| {
1976            &self.names[&NameKey::ExternalTextureFunctionArgument(
1977                handle,
1978                index as u32,
1979                ExternalTextureNameKey::Plane(i),
1980            )]
1981        });
1982        let params_name = &self.names[&NameKey::ExternalTextureFunctionArgument(
1983            handle,
1984            index as u32,
1985            ExternalTextureNameKey::Params,
1986        )];
1987        let params_ty_name =
1988            &self.names[&NameKey::Type(module.special_types.external_texture_params.unwrap())];
1989        write!(
1990            self.out,
1991            "Texture2D<float4> {}, Texture2D<float4> {}, Texture2D<float4> {}, {params_ty_name} {params_name}",
1992            plane_names[0], plane_names[1], plane_names[2],
1993        )?;
1994        Ok(())
1995    }
1996
1997    fn need_workgroup_variables_initialization(
1998        &mut self,
1999        func_ctx: &back::FunctionCtx,
2000        module: &Module,
2001    ) -> bool {
2002        self.options.zero_initialize_workgroup_memory
2003            && func_ctx.ty.is_compute_like_entry_point(module)
2004            && module.global_variables.iter().any(|(handle, var)| {
2005                !func_ctx.info[handle].is_empty() && var.space.is_workgroup_like()
2006            })
2007    }
2008
2009    pub(super) fn write_workgroup_variables_initialization(
2010        &mut self,
2011        func_ctx: &back::FunctionCtx,
2012        module: &Module,
2013        stage: ShaderStage,
2014    ) -> BackendResult {
2015        let vars = module.global_variables.iter().filter(|&(handle, var)| {
2016            // Read-only in mesh shaders
2017            let task_needs_zero =
2018                (var.space == crate::AddressSpace::TaskPayload) && stage == ShaderStage::Task;
2019            !func_ctx.info[handle].is_empty()
2020                && (var.space == crate::AddressSpace::WorkGroup || task_needs_zero)
2021        });
2022
2023        for (handle, var) in vars {
2024            let name = &self.names[&NameKey::GlobalVariable(handle)];
2025            write!(self.out, "{}{} = ", back::Level(2), name)?;
2026            self.write_default_init(module, var.ty)?;
2027            writeln!(self.out, ";")?;
2028        }
2029        Ok(())
2030    }
2031
2032    /// Helper method used to write switches
2033    fn write_switch(
2034        &mut self,
2035        module: &Module,
2036        func_ctx: &back::FunctionCtx<'_>,
2037        level: back::Level,
2038        selector: Handle<crate::Expression>,
2039        cases: &[crate::SwitchCase],
2040    ) -> BackendResult {
2041        // Write all cases
2042        let indent_level_1 = level.next();
2043        let indent_level_2 = indent_level_1.next();
2044
2045        // See docs of `back::continue_forward` module.
2046        if let Some(variable) = self.continue_ctx.enter_switch(&mut self.namer) {
2047            writeln!(self.out, "{level}bool {variable} = false;",)?;
2048        };
2049
2050        // Check if there is only one body, by seeing if all except the last case are fall through
2051        // with empty bodies. FXC doesn't handle these switches correctly, so
2052        // we generate a `do {} while(false);` loop instead. There must be a default case, so there
2053        // is no need to check if one of the cases would have matched.
2054        let one_body = cases
2055            .iter()
2056            .rev()
2057            .skip(1)
2058            .all(|case| case.fall_through && case.body.is_empty());
2059        if one_body {
2060            // Start the do-while
2061            writeln!(self.out, "{level}do {{")?;
2062            // Note: Expressions have no side-effects so we don't need to emit selector expression.
2063
2064            // Body
2065            if let Some(case) = cases.last() {
2066                for sta in case.body.iter() {
2067                    self.write_stmt(module, sta, func_ctx, indent_level_1)?;
2068                }
2069            }
2070            // End do-while
2071            writeln!(self.out, "{level}}} while(false);")?;
2072        } else {
2073            // Start the switch
2074            write!(self.out, "{level}")?;
2075            write!(self.out, "switch(")?;
2076            self.write_expr(module, selector, func_ctx)?;
2077            writeln!(self.out, ") {{")?;
2078
2079            for (i, case) in cases.iter().enumerate() {
2080                match case.value {
2081                    crate::SwitchValue::I32(value) => {
2082                        write!(self.out, "{indent_level_1}case {value}:")?
2083                    }
2084                    crate::SwitchValue::U32(value) => {
2085                        write!(self.out, "{indent_level_1}case {value}u:")?
2086                    }
2087                    crate::SwitchValue::Default => write!(self.out, "{indent_level_1}default:")?,
2088                }
2089
2090                // The new block is not only stylistic, it plays a role here:
2091                // We might end up having to write the same case body
2092                // multiple times due to FXC not supporting fallthrough.
2093                // Therefore, some `Expression`s written by `Statement::Emit`
2094                // will end up having the same name (`_expr<handle_index>`).
2095                // So we need to put each case in its own scope.
2096                let write_block_braces = !(case.fall_through && case.body.is_empty());
2097                if write_block_braces {
2098                    writeln!(self.out, " {{")?;
2099                } else {
2100                    writeln!(self.out)?;
2101                }
2102
2103                // Although FXC does support a series of case clauses before
2104                // a block[^yes], it does not support fallthrough from a
2105                // non-empty case block to the next[^no]. If this case has a
2106                // non-empty body with a fallthrough, emulate that by
2107                // duplicating the bodies of all the cases it would fall
2108                // into as extensions of this case's own body. This makes
2109                // the HLSL output potentially quadratic in the size of the
2110                // Naga IR.
2111                //
2112                // [^yes]: ```hlsl
2113                // case 1:
2114                // case 2: do_stuff()
2115                // ```
2116                // [^no]: ```hlsl
2117                // case 1: do_this();
2118                // case 2: do_that();
2119                // ```
2120                if case.fall_through && !case.body.is_empty() {
2121                    let curr_len = i + 1;
2122                    let end_case_idx = curr_len
2123                        + cases
2124                            .iter()
2125                            .skip(curr_len)
2126                            .position(|case| !case.fall_through)
2127                            .unwrap();
2128                    let indent_level_3 = indent_level_2.next();
2129                    for case in &cases[i..=end_case_idx] {
2130                        writeln!(self.out, "{indent_level_2}{{")?;
2131                        let prev_len = self.named_expressions.len();
2132                        for sta in case.body.iter() {
2133                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
2134                        }
2135                        // Clear all named expressions that were previously inserted by the statements in the block
2136                        self.named_expressions.truncate(prev_len);
2137                        writeln!(self.out, "{indent_level_2}}}")?;
2138                    }
2139
2140                    let last_case = &cases[end_case_idx];
2141                    if last_case.body.last().is_none_or(|s| !s.is_terminator()) {
2142                        writeln!(self.out, "{indent_level_2}break;")?;
2143                    }
2144                } else {
2145                    for sta in case.body.iter() {
2146                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
2147                    }
2148                    if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator()) {
2149                        writeln!(self.out, "{indent_level_2}break;")?;
2150                    }
2151                }
2152
2153                if write_block_braces {
2154                    writeln!(self.out, "{indent_level_1}}}")?;
2155                }
2156            }
2157
2158            writeln!(self.out, "{level}}}")?;
2159        }
2160
2161        // Handle any forwarded continue statements.
2162        use back::continue_forward::ExitControlFlow;
2163        let op = match self.continue_ctx.exit_switch() {
2164            ExitControlFlow::None => None,
2165            ExitControlFlow::Continue { variable } => Some(("continue", variable)),
2166            ExitControlFlow::Break { variable } => Some(("break", variable)),
2167        };
2168        if let Some((control_flow, variable)) = op {
2169            writeln!(self.out, "{level}if ({variable}) {{")?;
2170            writeln!(self.out, "{indent_level_1}{control_flow};")?;
2171            writeln!(self.out, "{level}}}")?;
2172        }
2173
2174        Ok(())
2175    }
2176
2177    fn write_index(
2178        &mut self,
2179        module: &Module,
2180        index: Index,
2181        func_ctx: &back::FunctionCtx<'_>,
2182    ) -> BackendResult {
2183        match index {
2184            Index::Static(index) => {
2185                write!(self.out, "{index}")?;
2186            }
2187            Index::Expression(index) => {
2188                self.write_expr(module, index, func_ctx)?;
2189            }
2190        }
2191        Ok(())
2192    }
2193
2194    /// Helper method used to write statements
2195    ///
2196    /// # Notes
2197    /// Always adds a newline
2198    fn write_stmt(
2199        &mut self,
2200        module: &Module,
2201        stmt: &crate::Statement,
2202        func_ctx: &back::FunctionCtx<'_>,
2203        level: back::Level,
2204    ) -> BackendResult {
2205        use crate::Statement;
2206
2207        match *stmt {
2208            Statement::Emit(ref range) => {
2209                for handle in range.clone() {
2210                    let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
2211                    let expr_name = if ptr_class.is_some() {
2212                        // HLSL can't save a pointer-valued expression in a variable,
2213                        // but we shouldn't ever need to: they should never be named expressions,
2214                        // and none of the expression types flagged by bake_ref_count can be pointer-valued.
2215                        None
2216                    } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
2217                        // Front end provides names for all variables at the start of writing.
2218                        // But we write them to step by step. We need to recache them
2219                        // Otherwise, we could accidentally write variable name instead of full expression.
2220                        // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
2221                        Some(self.namer.call(name))
2222                    } else if self.need_bake_expressions.contains(&handle) {
2223                        Some(Baked(handle).to_string())
2224                    } else {
2225                        None
2226                    };
2227
2228                    if let Some(name) = expr_name {
2229                        write!(self.out, "{level}")?;
2230                        self.write_named_expr(module, handle, name, handle, func_ctx)?;
2231                    }
2232                }
2233            }
2234            // TODO: copy-paste from glsl-out
2235            Statement::Block(ref block) => {
2236                write!(self.out, "{level}")?;
2237                writeln!(self.out, "{{")?;
2238                for sta in block.iter() {
2239                    // Increase the indentation to help with readability
2240                    self.write_stmt(module, sta, func_ctx, level.next())?
2241                }
2242                writeln!(self.out, "{level}}}")?
2243            }
2244            // TODO: copy-paste from glsl-out
2245            Statement::If {
2246                condition,
2247                ref accept,
2248                ref reject,
2249            } => {
2250                write!(self.out, "{level}")?;
2251                write!(self.out, "if (")?;
2252                self.write_expr(module, condition, func_ctx)?;
2253                writeln!(self.out, ") {{")?;
2254
2255                let l2 = level.next();
2256                for sta in accept {
2257                    // Increase indentation to help with readability
2258                    self.write_stmt(module, sta, func_ctx, l2)?;
2259                }
2260
2261                // If there are no statements in the reject block we skip writing it
2262                // This is only for readability
2263                if !reject.is_empty() {
2264                    writeln!(self.out, "{level}}} else {{")?;
2265
2266                    for sta in reject {
2267                        // Increase indentation to help with readability
2268                        self.write_stmt(module, sta, func_ctx, l2)?;
2269                    }
2270                }
2271
2272                writeln!(self.out, "{level}}}")?
2273            }
2274            // TODO: copy-paste from glsl-out
2275            Statement::Kill => writeln!(self.out, "{level}discard;")?,
2276            Statement::Return { value: None } => {
2277                writeln!(self.out, "{level}return;")?;
2278            }
2279            Statement::Return { value: Some(expr) } => {
2280                let base_ty_res = &func_ctx.info[expr].ty;
2281                let mut resolved = base_ty_res.inner_with(&module.types);
2282                if let TypeInner::Pointer { base, space: _ } = *resolved {
2283                    resolved = &module.types[base].inner;
2284                }
2285
2286                if let TypeInner::Struct { .. } = *resolved {
2287                    // We can safely unwrap here, since we know we are working with a struct
2288                    let ty = base_ty_res.handle().unwrap();
2289                    let struct_name = &self.names[&NameKey::Type(ty)];
2290                    let variable_name = self.namer.call(&struct_name.to_lowercase());
2291                    write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
2292                    self.write_expr(module, expr, func_ctx)?;
2293                    writeln!(self.out, ";")?;
2294
2295                    // for entry point returns, we may need to reshuffle the outputs into a different struct
2296                    let ep_output = match func_ctx.ty {
2297                        back::FunctionType::Function(_) => None,
2298                        back::FunctionType::EntryPoint(index) => self
2299                            .entry_point_io
2300                            .get(&(index as usize))
2301                            .unwrap()
2302                            .output
2303                            .as_ref(),
2304                    };
2305                    let final_name = match ep_output {
2306                        Some(ep_output) => {
2307                            let final_name = self.namer.call(&variable_name);
2308                            write!(
2309                                self.out,
2310                                "{}const {} {} = {{ ",
2311                                level, ep_output.ty_name, final_name,
2312                            )?;
2313                            for (index, m) in ep_output.members.iter().enumerate() {
2314                                if index != 0 {
2315                                    write!(self.out, ", ")?;
2316                                }
2317                                let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
2318                                write!(self.out, "{variable_name}.{member_name}")?;
2319                            }
2320                            writeln!(self.out, " }};")?;
2321                            final_name
2322                        }
2323                        None => variable_name,
2324                    };
2325                    writeln!(self.out, "{level}return {final_name};")?;
2326                } else {
2327                    write!(self.out, "{level}return ")?;
2328                    self.write_expr(module, expr, func_ctx)?;
2329                    writeln!(self.out, ";")?
2330                }
2331            }
2332            Statement::Store { pointer, value } => {
2333                let ty_inner = func_ctx.resolve_type(pointer, &module.types);
2334                if ty_inner.is_atomic_pointer(&module.types) {
2335                    let pointer_space = ty_inner.pointer_space().unwrap();
2336                    let dummy = self.namer.call("dummy");
2337                    write!(self.out, "{level}{{ ")?;
2338                    if let TypeInner::Pointer { base, .. } = *ty_inner {
2339                        self.write_value_type(module, &module.types[base].inner)?;
2340                    }
2341                    write!(self.out, " {dummy} = 0; ")?;
2342                    match pointer_space {
2343                        crate::AddressSpace::WorkGroup => {
2344                            write!(self.out, "InterlockedExchange(")?;
2345                            self.write_expr(module, pointer, func_ctx)?;
2346                        }
2347                        crate::AddressSpace::Storage { .. } => {
2348                            let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2349                            let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2350                            write!(self.out, "{var_name}.InterlockedExchange(")?;
2351                            let chain = mem::take(&mut self.temp_access_chain);
2352                            self.write_storage_address(module, &chain, func_ctx)?;
2353                            self.temp_access_chain = chain;
2354                        }
2355                        _ => unreachable!(),
2356                    }
2357                    write!(self.out, ", ")?;
2358                    self.write_expr(module, value, func_ctx)?;
2359                    writeln!(self.out, ", {dummy}); }}")?;
2360                } else if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
2361                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2362                    self.write_storage_store(
2363                        module,
2364                        var_handle,
2365                        StoreValue::Expression(value),
2366                        func_ctx,
2367                        level,
2368                        None,
2369                    )?;
2370                } else {
2371                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
2372                    // See the module-level block comment in mod.rs for details.
2373                    //
2374                    // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
2375                    // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
2376                    enum MatrixAccess {
2377                        Direct {
2378                            base: Handle<crate::Expression>,
2379                            index: u32,
2380                        },
2381                        Struct {
2382                            columns: crate::VectorSize,
2383                            width: u8,
2384                            base: Handle<crate::Expression>,
2385                        },
2386                    }
2387
2388                    let get_members = |expr: Handle<crate::Expression>| {
2389                        let resolved = func_ctx.resolve_type(expr, &module.types);
2390                        match *resolved {
2391                            TypeInner::Pointer { base, .. } => match module.types[base].inner {
2392                                TypeInner::Struct { ref members, .. } => Some(members),
2393                                _ => None,
2394                            },
2395                            _ => None,
2396                        }
2397                    };
2398
2399                    write!(self.out, "{level}")?;
2400
2401                    let matrix_access_on_lhs =
2402                        find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
2403                            |(matrix_expr, vector, scalar)| match (
2404                                func_ctx.resolve_type(matrix_expr, &module.types),
2405                                &func_ctx.expressions[matrix_expr],
2406                            ) {
2407                                (
2408                                    &TypeInner::Pointer { base: ty, .. },
2409                                    &crate::Expression::AccessIndex { base, index },
2410                                ) if matches!(
2411                                    module.types[ty].inner,
2412                                    TypeInner::Matrix {
2413                                        rows: crate::VectorSize::Bi,
2414                                        ..
2415                                    }
2416                                ) && get_members(base)
2417                                    .map(|members| members[index as usize].binding.is_none())
2418                                    == Some(true) =>
2419                                {
2420                                    Some((MatrixAccess::Direct { base, index }, vector, scalar))
2421                                }
2422                                _ => {
2423                                    if let Some(MatrixType {
2424                                        columns,
2425                                        rows: crate::VectorSize::Bi,
2426                                        width,
2427                                    }) = get_inner_matrix_of_struct_array_member(
2428                                        module,
2429                                        matrix_expr,
2430                                        func_ctx,
2431                                        true,
2432                                    ) {
2433                                        Some((
2434                                            MatrixAccess::Struct {
2435                                                columns,
2436                                                width,
2437                                                base: matrix_expr,
2438                                            },
2439                                            vector,
2440                                            scalar,
2441                                        ))
2442                                    } else {
2443                                        None
2444                                    }
2445                                }
2446                            },
2447                        );
2448
2449                    match matrix_access_on_lhs {
2450                        Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
2451                            let base_ty_res = &func_ctx.info[base].ty;
2452                            let resolved = base_ty_res.inner_with(&module.types);
2453                            let ty = match *resolved {
2454                                TypeInner::Pointer { base, .. } => base,
2455                                _ => base_ty_res.handle().unwrap(),
2456                            };
2457
2458                            if let Some(Index::Static(vec_index)) = vector {
2459                                self.write_expr(module, base, func_ctx)?;
2460                                write!(
2461                                    self.out,
2462                                    ".{}_{}",
2463                                    &self.names[&NameKey::StructMember(ty, index)],
2464                                    vec_index
2465                                )?;
2466
2467                                if let Some(scalar_index) = scalar {
2468                                    write!(self.out, "[")?;
2469                                    self.write_index(module, scalar_index, func_ctx)?;
2470                                    write!(self.out, "]")?;
2471                                }
2472
2473                                write!(self.out, " = ")?;
2474                                self.write_expr(module, value, func_ctx)?;
2475                                writeln!(self.out, ";")?;
2476                            } else {
2477                                let access = WrappedStructMatrixAccess { ty, index };
2478                                match (&vector, &scalar) {
2479                                    (&Some(_), &Some(_)) => {
2480                                        self.write_wrapped_struct_matrix_set_scalar_function_name(
2481                                            access,
2482                                        )?;
2483                                    }
2484                                    (&Some(_), &None) => {
2485                                        self.write_wrapped_struct_matrix_set_vec_function_name(
2486                                            access,
2487                                        )?;
2488                                    }
2489                                    (&None, _) => {
2490                                        self.write_wrapped_struct_matrix_set_function_name(access)?;
2491                                    }
2492                                }
2493
2494                                write!(self.out, "(")?;
2495                                self.write_expr(module, base, func_ctx)?;
2496                                write!(self.out, ", ")?;
2497                                self.write_expr(module, value, func_ctx)?;
2498
2499                                if let Some(Index::Expression(vec_index)) = vector {
2500                                    write!(self.out, ", ")?;
2501                                    self.write_expr(module, vec_index, func_ctx)?;
2502
2503                                    if let Some(scalar_index) = scalar {
2504                                        write!(self.out, ", ")?;
2505                                        self.write_index(module, scalar_index, func_ctx)?;
2506                                    }
2507                                }
2508                                writeln!(self.out, ");")?;
2509                            }
2510                        }
2511                        Some((
2512                            MatrixAccess::Struct {
2513                                columns,
2514                                width,
2515                                base,
2516                            },
2517                            Some(Index::Expression(vec_index)),
2518                            scalar,
2519                        )) => {
2520                            // We handle `Store`s to __matCx2 column vectors and scalar elements via
2521                            // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
2522
2523                            if scalar.is_some() {
2524                                write!(
2525                                    self.out,
2526                                    "__set_el_of_mat{}x2_f{}",
2527                                    columns as u8,
2528                                    width * 8
2529                                )?;
2530                            } else {
2531                                write!(
2532                                    self.out,
2533                                    "__set_col_of_mat{}x2_f{}",
2534                                    columns as u8,
2535                                    width * 8
2536                                )?;
2537                            }
2538                            write!(self.out, "(")?;
2539                            self.write_expr(module, base, func_ctx)?;
2540                            write!(self.out, ", ")?;
2541                            self.write_expr(module, vec_index, func_ctx)?;
2542
2543                            if let Some(scalar_index) = scalar {
2544                                write!(self.out, ", ")?;
2545                                self.write_index(module, scalar_index, func_ctx)?;
2546                            }
2547
2548                            write!(self.out, ", ")?;
2549                            self.write_expr(module, value, func_ctx)?;
2550
2551                            writeln!(self.out, ");")?;
2552                        }
2553                        Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
2554                        | Some((MatrixAccess::Struct { .. }, None, _))
2555                        | None => {
2556                            self.write_expr(module, pointer, func_ctx)?;
2557                            write!(self.out, " = ")?;
2558
2559                            // We cast the RHS of this store in cases where the LHS
2560                            // is a struct member with type:
2561                            //  - matCx2 or
2562                            //  - a (possibly nested) array of matCx2's
2563                            if let Some(MatrixType {
2564                                columns,
2565                                rows: crate::VectorSize::Bi,
2566                                width,
2567                            }) = get_inner_matrix_of_struct_array_member(
2568                                module, pointer, func_ctx, false,
2569                            ) {
2570                                let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2571                                if let TypeInner::Pointer { base, .. } = *resolved {
2572                                    resolved = &module.types[base].inner;
2573                                }
2574
2575                                write!(self.out, "(__mat{}x2_f{}", columns as u8, width * 8)?;
2576                                if let TypeInner::Array { base, size, .. } = *resolved {
2577                                    self.write_array_size(module, base, size)?;
2578                                }
2579                                write!(self.out, ")")?;
2580                            }
2581
2582                            self.write_expr(module, value, func_ctx)?;
2583                            writeln!(self.out, ";")?
2584                        }
2585                    }
2586                }
2587            }
2588            Statement::Loop {
2589                ref body,
2590                ref continuing,
2591                break_if,
2592            } => {
2593                let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2594                let gate_name = (!continuing.is_empty() || break_if.is_some())
2595                    .then(|| self.namer.call("loop_init"));
2596
2597                if let Some((ref decl, _)) = force_loop_bound_statements {
2598                    writeln!(self.out, "{decl}")?;
2599                }
2600                if let Some(ref gate_name) = gate_name {
2601                    writeln!(self.out, "{level}bool {gate_name} = true;")?;
2602                }
2603
2604                self.continue_ctx.enter_loop();
2605                writeln!(self.out, "{level}while(true) {{")?;
2606                if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2607                    writeln!(self.out, "{break_and_inc}")?;
2608                }
2609                let l2 = level.next();
2610                if let Some(gate_name) = gate_name {
2611                    writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
2612                    let l3 = l2.next();
2613                    for sta in continuing.iter() {
2614                        self.write_stmt(module, sta, func_ctx, l3)?;
2615                    }
2616                    if let Some(condition) = break_if {
2617                        write!(self.out, "{l3}if (")?;
2618                        self.write_expr(module, condition, func_ctx)?;
2619                        writeln!(self.out, ") {{")?;
2620                        writeln!(self.out, "{}break;", l3.next())?;
2621                        writeln!(self.out, "{l3}}}")?;
2622                    }
2623                    writeln!(self.out, "{l2}}}")?;
2624                    writeln!(self.out, "{l2}{gate_name} = false;")?;
2625                }
2626
2627                for sta in body.iter() {
2628                    self.write_stmt(module, sta, func_ctx, l2)?;
2629                }
2630
2631                writeln!(self.out, "{level}}}")?;
2632                self.continue_ctx.exit_loop();
2633            }
2634            Statement::Break => writeln!(self.out, "{level}break;")?,
2635            Statement::Continue => {
2636                if let Some(variable) = self.continue_ctx.continue_encountered() {
2637                    writeln!(self.out, "{level}{variable} = true;")?;
2638                    writeln!(self.out, "{level}break;")?
2639                } else {
2640                    writeln!(self.out, "{level}continue;")?
2641                }
2642            }
2643            Statement::ControlBarrier(barrier) => {
2644                self.write_control_barrier(barrier, level)?;
2645            }
2646            Statement::MemoryBarrier(barrier) => {
2647                self.write_memory_barrier(barrier, level)?;
2648            }
2649            Statement::ImageStore {
2650                image,
2651                coordinate,
2652                array_index,
2653                value,
2654            } => {
2655                write!(self.out, "{level}")?;
2656                self.write_expr(module, image, func_ctx)?;
2657
2658                write!(self.out, "[")?;
2659                if let Some(index) = array_index {
2660                    // An array index is accepted only for texture_storage_2d_array, so we can safely use int3(coordinate, array_index) here
2661                    write!(self.out, "int3(")?;
2662                    self.write_expr(module, coordinate, func_ctx)?;
2663                    write!(self.out, ", ")?;
2664                    self.write_expr(module, index, func_ctx)?;
2665                    write!(self.out, ")")?;
2666                } else {
2667                    self.write_expr(module, coordinate, func_ctx)?;
2668                }
2669                write!(self.out, "]")?;
2670
2671                write!(self.out, " = ")?;
2672                self.write_expr(module, value, func_ctx)?;
2673                writeln!(self.out, ";")?;
2674            }
2675            Statement::Call {
2676                function,
2677                ref arguments,
2678                result,
2679            } => {
2680                write!(self.out, "{level}")?;
2681
2682                if let Some(expr) = result {
2683                    write!(self.out, "const ")?;
2684                    let name = Baked(expr).to_string();
2685                    let expr_ty = &func_ctx.info[expr].ty;
2686                    let ty_inner = match *expr_ty {
2687                        proc::TypeResolution::Handle(handle) => {
2688                            self.write_type(module, handle)?;
2689                            &module.types[handle].inner
2690                        }
2691                        proc::TypeResolution::Value(ref value) => {
2692                            self.write_value_type(module, value)?;
2693                            value
2694                        }
2695                    };
2696                    write!(self.out, " {name}")?;
2697                    if let TypeInner::Array { base, size, .. } = *ty_inner {
2698                        self.write_array_size(module, base, size)?;
2699                    }
2700                    write!(self.out, " = ")?;
2701                    self.named_expressions.insert(expr, name);
2702                }
2703                let func_name = &self.names[&NameKey::Function(function)];
2704                write!(self.out, "{func_name}(")?;
2705                let mut any_args_written = false;
2706                let mut separator = || {
2707                    if any_args_written {
2708                        ", "
2709                    } else {
2710                        any_args_written = true;
2711                        ""
2712                    }
2713                };
2714                for argument in arguments {
2715                    write!(self.out, "{}", separator())?;
2716                    self.write_expr(module, *argument, func_ctx)?;
2717                }
2718                if let Some(&var) = self.function_task_payload_var.get(&function) {
2719                    let name = &self.names[&NameKey::GlobalVariable(var)];
2720                    // Pass it through directly, whether it's an in variable to this function or the global variable
2721                    write!(self.out, "{}{name}", separator())?;
2722                }
2723                writeln!(self.out, ");")?;
2724            }
2725            Statement::Atomic {
2726                pointer,
2727                ref fun,
2728                value,
2729                result,
2730            } => {
2731                write!(self.out, "{level}")?;
2732                let res_var_info = if let Some(res_handle) = result {
2733                    let name = Baked(res_handle).to_string();
2734                    match func_ctx.info[res_handle].ty {
2735                        proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2736                        proc::TypeResolution::Value(ref value) => {
2737                            self.write_value_type(module, value)?
2738                        }
2739                    };
2740                    write!(self.out, " {name}; ")?;
2741                    self.named_expressions.insert(res_handle, name.clone());
2742                    Some((res_handle, name))
2743                } else {
2744                    None
2745                };
2746                let pointer_space = func_ctx
2747                    .resolve_type(pointer, &module.types)
2748                    .pointer_space()
2749                    .unwrap();
2750                let fun_str = fun.to_hlsl_suffix();
2751                let compare_expr = match *fun {
2752                    crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2753                    _ => None,
2754                };
2755                match pointer_space {
2756                    crate::AddressSpace::WorkGroup => {
2757                        write!(self.out, "Interlocked{fun_str}(")?;
2758                        self.write_expr(module, pointer, func_ctx)?;
2759                        self.emit_hlsl_atomic_tail(
2760                            module,
2761                            func_ctx,
2762                            fun,
2763                            compare_expr,
2764                            value,
2765                            &res_var_info,
2766                        )?;
2767                    }
2768                    crate::AddressSpace::Storage { .. } => {
2769                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2770                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
2771                        let width = match func_ctx.resolve_type(value, &module.types) {
2772                            &TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
2773                            _ => "",
2774                        };
2775                        write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2776                        let chain = mem::take(&mut self.temp_access_chain);
2777                        self.write_storage_address(module, &chain, func_ctx)?;
2778                        self.temp_access_chain = chain;
2779                        self.emit_hlsl_atomic_tail(
2780                            module,
2781                            func_ctx,
2782                            fun,
2783                            compare_expr,
2784                            value,
2785                            &res_var_info,
2786                        )?;
2787                    }
2788                    ref other => {
2789                        return Err(Error::Custom(format!(
2790                            "invalid address space {other:?} for atomic statement"
2791                        )))
2792                    }
2793                }
2794                if let Some(cmp) = compare_expr {
2795                    if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2796                        write!(
2797                            self.out,
2798                            "{level}{res_name}.exchanged = ({res_name}.old_value == "
2799                        )?;
2800                        self.write_expr(module, cmp, func_ctx)?;
2801                        writeln!(self.out, ");")?;
2802                    }
2803                }
2804            }
2805            Statement::ImageAtomic {
2806                image,
2807                coordinate,
2808                array_index,
2809                fun,
2810                value,
2811            } => {
2812                write!(self.out, "{level}")?;
2813
2814                let fun_str = fun.to_hlsl_suffix();
2815                write!(self.out, "Interlocked{fun_str}(")?;
2816                self.write_expr(module, image, func_ctx)?;
2817                write!(self.out, "[")?;
2818                self.write_texture_coordinates(
2819                    "int",
2820                    coordinate,
2821                    array_index,
2822                    None,
2823                    module,
2824                    func_ctx,
2825                )?;
2826                write!(self.out, "],")?;
2827
2828                self.write_expr(module, value, func_ctx)?;
2829                writeln!(self.out, ");")?;
2830            }
2831            Statement::WorkGroupUniformLoad { pointer, result } => {
2832                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2833                write!(self.out, "{level}")?;
2834                let name = Baked(result).to_string();
2835                self.write_named_expr(module, pointer, name, result, func_ctx)?;
2836
2837                self.write_control_barrier(crate::Barrier::WORK_GROUP, level)?;
2838            }
2839            Statement::Switch {
2840                selector,
2841                ref cases,
2842            } => {
2843                self.write_switch(module, func_ctx, level, selector, cases)?;
2844            }
2845            Statement::RayQuery { query, ref fun } => {
2846                // There are three possibilities for a ptr to be:
2847                // 1. A variable
2848                // 2. A function argument
2849                // 3. part of a struct
2850                //
2851                // 2 and 3 are not possible, a ray query (in naga IR)
2852                // is not allowed to be passed into a function, and
2853                // all languages disallow it in a struct (you get fun results if
2854                // you try it :) ).
2855                //
2856                // Therefore, the ray query expression must be a variable.
2857                let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query]
2858                else {
2859                    unreachable!()
2860                };
2861
2862                let tracker_expr_name = format!(
2863                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
2864                    self.names[&func_ctx.name_key(query_var)]
2865                );
2866
2867                match *fun {
2868                    RayQueryFunction::Initialize {
2869                        acceleration_structure,
2870                        descriptor,
2871                    } => {
2872                        self.write_initialize_function(
2873                            module,
2874                            level,
2875                            query,
2876                            acceleration_structure,
2877                            descriptor,
2878                            &tracker_expr_name,
2879                            func_ctx,
2880                        )?;
2881                    }
2882                    RayQueryFunction::Proceed { result } => {
2883                        self.write_proceed(
2884                            module,
2885                            level,
2886                            query,
2887                            result,
2888                            &tracker_expr_name,
2889                            func_ctx,
2890                        )?;
2891                    }
2892                    RayQueryFunction::GenerateIntersection { hit_t } => {
2893                        self.write_generate_intersection(
2894                            module,
2895                            level,
2896                            query,
2897                            hit_t,
2898                            &tracker_expr_name,
2899                            func_ctx,
2900                        )?;
2901                    }
2902                    RayQueryFunction::ConfirmIntersection => {
2903                        self.write_confirm_intersection(
2904                            module,
2905                            level,
2906                            query,
2907                            &tracker_expr_name,
2908                            func_ctx,
2909                        )?;
2910                    }
2911                    RayQueryFunction::Terminate => {
2912                        self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?;
2913                    }
2914                }
2915            }
2916            Statement::SubgroupBallot { result, predicate } => {
2917                write!(self.out, "{level}")?;
2918                let name = Baked(result).to_string();
2919                write!(self.out, "const uint4 {name} = ")?;
2920                self.named_expressions.insert(result, name);
2921
2922                write!(self.out, "WaveActiveBallot(")?;
2923                match predicate {
2924                    Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
2925                    None => write!(self.out, "true")?,
2926                }
2927                writeln!(self.out, ");")?;
2928            }
2929            Statement::SubgroupCollectiveOperation {
2930                op,
2931                collective_op,
2932                argument,
2933                result,
2934            } => {
2935                write!(self.out, "{level}")?;
2936                write!(self.out, "const ")?;
2937                let name = Baked(result).to_string();
2938                match func_ctx.info[result].ty {
2939                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2940                    proc::TypeResolution::Value(ref value) => {
2941                        self.write_value_type(module, value)?
2942                    }
2943                };
2944                write!(self.out, " {name} = ")?;
2945                self.named_expressions.insert(result, name);
2946
2947                match (collective_op, op) {
2948                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
2949                        write!(self.out, "WaveActiveAllTrue(")?
2950                    }
2951                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
2952                        write!(self.out, "WaveActiveAnyTrue(")?
2953                    }
2954                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
2955                        write!(self.out, "WaveActiveSum(")?
2956                    }
2957                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
2958                        write!(self.out, "WaveActiveProduct(")?
2959                    }
2960                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
2961                        write!(self.out, "WaveActiveMax(")?
2962                    }
2963                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
2964                        write!(self.out, "WaveActiveMin(")?
2965                    }
2966                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
2967                        write!(self.out, "WaveActiveBitAnd(")?
2968                    }
2969                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
2970                        write!(self.out, "WaveActiveBitOr(")?
2971                    }
2972                    (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
2973                        write!(self.out, "WaveActiveBitXor(")?
2974                    }
2975                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
2976                        write!(self.out, "WavePrefixSum(")?
2977                    }
2978                    (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
2979                        write!(self.out, "WavePrefixProduct(")?
2980                    }
2981                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
2982                        self.write_expr(module, argument, func_ctx)?;
2983                        write!(self.out, " + WavePrefixSum(")?;
2984                    }
2985                    (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
2986                        self.write_expr(module, argument, func_ctx)?;
2987                        write!(self.out, " * WavePrefixProduct(")?;
2988                    }
2989                    _ => unimplemented!(),
2990                }
2991                self.write_expr(module, argument, func_ctx)?;
2992                writeln!(self.out, ");")?;
2993            }
2994            Statement::SubgroupGather {
2995                mode,
2996                argument,
2997                result,
2998            } => {
2999                write!(self.out, "{level}")?;
3000                write!(self.out, "const ")?;
3001                let name = Baked(result).to_string();
3002                match func_ctx.info[result].ty {
3003                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
3004                    proc::TypeResolution::Value(ref value) => {
3005                        self.write_value_type(module, value)?
3006                    }
3007                };
3008                write!(self.out, " {name} = ")?;
3009                self.named_expressions.insert(result, name);
3010                match mode {
3011                    crate::GatherMode::BroadcastFirst => {
3012                        write!(self.out, "WaveReadLaneFirst(")?;
3013                        self.write_expr(module, argument, func_ctx)?;
3014                    }
3015                    crate::GatherMode::QuadBroadcast(index) => {
3016                        write!(self.out, "QuadReadLaneAt(")?;
3017                        self.write_expr(module, argument, func_ctx)?;
3018                        write!(self.out, ", ")?;
3019                        self.write_expr(module, index, func_ctx)?;
3020                    }
3021                    crate::GatherMode::QuadSwap(direction) => {
3022                        match direction {
3023                            crate::Direction::X => {
3024                                write!(self.out, "QuadReadAcrossX(")?;
3025                            }
3026                            crate::Direction::Y => {
3027                                write!(self.out, "QuadReadAcrossY(")?;
3028                            }
3029                            crate::Direction::Diagonal => {
3030                                write!(self.out, "QuadReadAcrossDiagonal(")?;
3031                            }
3032                        }
3033                        self.write_expr(module, argument, func_ctx)?;
3034                    }
3035                    _ => {
3036                        write!(self.out, "WaveReadLaneAt(")?;
3037                        self.write_expr(module, argument, func_ctx)?;
3038                        write!(self.out, ", ")?;
3039                        match mode {
3040                            crate::GatherMode::BroadcastFirst => unreachable!(),
3041                            crate::GatherMode::Broadcast(index)
3042                            | crate::GatherMode::Shuffle(index) => {
3043                                self.write_expr(module, index, func_ctx)?;
3044                            }
3045                            crate::GatherMode::ShuffleDown(index) => {
3046                                write!(self.out, "WaveGetLaneIndex() + ")?;
3047                                self.write_expr(module, index, func_ctx)?;
3048                            }
3049                            crate::GatherMode::ShuffleUp(index) => {
3050                                write!(self.out, "WaveGetLaneIndex() - ")?;
3051                                self.write_expr(module, index, func_ctx)?;
3052                            }
3053                            crate::GatherMode::ShuffleXor(index) => {
3054                                write!(self.out, "WaveGetLaneIndex() ^ ")?;
3055                                self.write_expr(module, index, func_ctx)?;
3056                            }
3057                            crate::GatherMode::QuadBroadcast(_) => unreachable!(),
3058                            crate::GatherMode::QuadSwap(_) => unreachable!(),
3059                        }
3060                    }
3061                }
3062                writeln!(self.out, ");")?;
3063            }
3064            Statement::CooperativeStore { .. } => unimplemented!(),
3065            Statement::RayPipelineFunction(_) => unreachable!(),
3066        }
3067
3068        Ok(())
3069    }
3070
3071    fn write_const_expression(
3072        &mut self,
3073        module: &Module,
3074        expr: Handle<crate::Expression>,
3075        arena: &crate::Arena<crate::Expression>,
3076    ) -> BackendResult {
3077        self.write_possibly_const_expression(module, expr, arena, |writer, expr| {
3078            writer.write_const_expression(module, expr, arena)
3079        })
3080    }
3081
3082    pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
3083        match literal {
3084            crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
3085            crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
3086            crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
3087            crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
3088            crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
3089            crate::Literal::U32(value) => write!(self.out, "{value}u")?,
3090            // `-2147483648` is parsed by some compilers as unary negation of
3091            // positive 2147483648, which is too large for an int, causing
3092            // issues for some compilers. Neither DXC nor FXC appear to have
3093            // this problem, but this is not specified and could change. We
3094            // therefore use `-2147483647 - 1` as a precaution.
3095            crate::Literal::I32(value) if value == i32::MIN => {
3096                write!(self.out, "int({} - 1)", value + 1)?
3097            }
3098            // HLSL has no suffix for explicit i32 literals, but not using any suffix
3099            // makes the type ambiguous which prevents overload resolution from
3100            // working. So we explicitly use the int() constructor syntax.
3101            crate::Literal::I32(value) => write!(self.out, "int({value})")?,
3102            crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
3103            // I64 version of the minimum I32 value issue described above.
3104            crate::Literal::I64(value) if value == i64::MIN => {
3105                write!(self.out, "({}L - 1L)", value + 1)?;
3106            }
3107            crate::Literal::I64(value) => write!(self.out, "{value}L")?,
3108            crate::Literal::Bool(value) => write!(self.out, "{value}")?,
3109            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
3110                return Err(Error::Custom(
3111                    "Abstract types should not appear in IR presented to backends".into(),
3112                ));
3113            }
3114        }
3115        Ok(())
3116    }
3117
3118    fn write_possibly_const_expression<E>(
3119        &mut self,
3120        module: &Module,
3121        expr: Handle<crate::Expression>,
3122        expressions: &crate::Arena<crate::Expression>,
3123        write_expression: E,
3124    ) -> BackendResult
3125    where
3126        E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
3127    {
3128        use crate::Expression;
3129
3130        match expressions[expr] {
3131            Expression::Literal(literal) => {
3132                self.write_literal(literal)?;
3133            }
3134            Expression::Constant(handle) => {
3135                let constant = &module.constants[handle];
3136                if constant.name.is_some() {
3137                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
3138                } else {
3139                    self.write_const_expression(module, constant.init, &module.global_expressions)?;
3140                }
3141            }
3142            Expression::ZeroValue(ty) => {
3143                self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
3144                write!(self.out, "()")?;
3145            }
3146            Expression::Compose { ty, ref components } => {
3147                match module.types[ty].inner {
3148                    TypeInner::Struct { .. } | TypeInner::Array { .. } => {
3149                        self.write_wrapped_constructor_function_name(
3150                            module,
3151                            WrappedConstructor { ty },
3152                        )?;
3153                    }
3154                    _ => {
3155                        self.write_type(module, ty)?;
3156                    }
3157                };
3158                write!(self.out, "(")?;
3159                for (index, component) in components.iter().enumerate() {
3160                    if index != 0 {
3161                        write!(self.out, ", ")?;
3162                    }
3163                    write_expression(self, *component)?;
3164                }
3165                write!(self.out, ")")?;
3166            }
3167            Expression::Splat { size, value } => {
3168                // hlsl is not supported one value constructor
3169                // if we write, for example, int4(0), dxc returns error:
3170                // error: too few elements in vector initialization (expected 4 elements, have 1)
3171                let number_of_components = match size {
3172                    crate::VectorSize::Bi => "xx",
3173                    crate::VectorSize::Tri => "xxx",
3174                    crate::VectorSize::Quad => "xxxx",
3175                };
3176                write!(self.out, "(")?;
3177                write_expression(self, value)?;
3178                write!(self.out, ").{number_of_components}")?
3179            }
3180            _ => {
3181                return Err(Error::Override);
3182            }
3183        }
3184
3185        Ok(())
3186    }
3187
3188    /// Helper method to write expressions
3189    ///
3190    /// # Notes
3191    /// Doesn't add any newlines or leading/trailing spaces
3192    pub(super) fn write_expr(
3193        &mut self,
3194        module: &Module,
3195        expr: Handle<crate::Expression>,
3196        func_ctx: &back::FunctionCtx<'_>,
3197    ) -> BackendResult {
3198        use crate::Expression;
3199
3200        // Handle the special semantics of vertex_index/instance_index
3201        let ff_input = if self.options.special_constants_binding.is_some() {
3202            func_ctx.is_fixed_function_input(expr, module)
3203        } else {
3204            None
3205        };
3206        let closing_bracket = match ff_input {
3207            Some(crate::BuiltIn::VertexIndex) => {
3208                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
3209                ")"
3210            }
3211            Some(crate::BuiltIn::InstanceIndex) => {
3212                write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
3213                ")"
3214            }
3215            Some(crate::BuiltIn::NumWorkGroups) => {
3216                // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
3217                // in compute shaders the special constants contain the number
3218                // of workgroups, which we are using here.
3219                write!(
3220                    self.out,
3221                    "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
3222                )?;
3223                return Ok(());
3224            }
3225            _ => "",
3226        };
3227
3228        if let Some(name) = self.named_expressions.get(&expr) {
3229            write!(self.out, "{name}{closing_bracket}")?;
3230            return Ok(());
3231        }
3232
3233        let expression = &func_ctx.expressions[expr];
3234
3235        match *expression {
3236            Expression::Literal(_)
3237            | Expression::Constant(_)
3238            | Expression::ZeroValue(_)
3239            | Expression::Compose { .. }
3240            | Expression::Splat { .. } => {
3241                self.write_possibly_const_expression(
3242                    module,
3243                    expr,
3244                    func_ctx.expressions,
3245                    |writer, expr| writer.write_expr(module, expr, func_ctx),
3246                )?;
3247            }
3248            Expression::Override(_) => return Err(Error::Override),
3249            // Avoid undefined behaviour for addition, subtraction, and
3250            // multiplication of signed integers by casting operands to
3251            // unsigned, performing the operation, then casting the result back
3252            // to signed.
3253            // TODO(#7109): This relies on the asint()/asuint() functions which only work
3254            // for 32-bit types, so we must find another solution for different bit widths.
3255            Expression::Binary {
3256                op:
3257                    op @ crate::BinaryOperator::Add
3258                    | op @ crate::BinaryOperator::Subtract
3259                    | op @ crate::BinaryOperator::Multiply,
3260                left,
3261                right,
3262            } if matches!(
3263                func_ctx.resolve_type(expr, &module.types).scalar(),
3264                Some(Scalar::I32)
3265            ) =>
3266            {
3267                write!(self.out, "asint(asuint(",)?;
3268                self.write_expr(module, left, func_ctx)?;
3269                write!(self.out, ") {} asuint(", back::binary_operation_str(op))?;
3270                self.write_expr(module, right, func_ctx)?;
3271                write!(self.out, "))")?;
3272            }
3273            // All of the multiplication can be expressed as `mul`,
3274            // except vector * vector, which needs to use the "*" operator.
3275            Expression::Binary {
3276                op: crate::BinaryOperator::Multiply,
3277                left,
3278                right,
3279            } if func_ctx.resolve_type(left, &module.types).is_matrix()
3280                || func_ctx.resolve_type(right, &module.types).is_matrix() =>
3281            {
3282                // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
3283                write!(self.out, "mul(")?;
3284                self.write_expr(module, right, func_ctx)?;
3285                write!(self.out, ", ")?;
3286                self.write_expr(module, left, func_ctx)?;
3287                write!(self.out, ")")?;
3288            }
3289
3290            // WGSL says that floating-point division by zero should return
3291            // infinity. Microsoft's Direct3D 11 functional specification
3292            // (https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm)
3293            // says:
3294            //
3295            //     Divide by 0 produces +/- INF, except 0/0 which results in NaN.
3296            //
3297            // which is what we want. The DXIL specification for the FDiv
3298            // instruction corroborates this:
3299            //
3300            // https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#fdiv
3301            Expression::Binary {
3302                op: crate::BinaryOperator::Divide,
3303                left,
3304                right,
3305            } if matches!(
3306                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3307                Some(ScalarKind::Sint | ScalarKind::Uint)
3308            ) =>
3309            {
3310                write!(self.out, "{DIV_FUNCTION}(")?;
3311                self.write_expr(module, left, func_ctx)?;
3312                write!(self.out, ", ")?;
3313                self.write_expr(module, right, func_ctx)?;
3314                write!(self.out, ")")?;
3315            }
3316
3317            Expression::Binary {
3318                op: crate::BinaryOperator::Modulo,
3319                left,
3320                right,
3321            } if matches!(
3322                func_ctx.resolve_type(expr, &module.types).scalar_kind(),
3323                Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
3324            ) =>
3325            {
3326                write!(self.out, "{MOD_FUNCTION}(")?;
3327                self.write_expr(module, left, func_ctx)?;
3328                write!(self.out, ", ")?;
3329                self.write_expr(module, right, func_ctx)?;
3330                write!(self.out, ")")?;
3331            }
3332
3333            Expression::Binary { op, left, right } => {
3334                write!(self.out, "(")?;
3335                self.write_expr(module, left, func_ctx)?;
3336                write!(self.out, " {} ", back::binary_operation_str(op))?;
3337                self.write_expr(module, right, func_ctx)?;
3338                write!(self.out, ")")?;
3339            }
3340            Expression::Access { base, index } => {
3341                if let Some(crate::AddressSpace::Storage { .. }) =
3342                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3343                {
3344                    // do nothing, the chain is written on `Load`/`Store`
3345                } else {
3346                    // We use the function __get_col_of_matCx2 here in cases
3347                    // where `base`s type resolves to a matCx2 and is part of a
3348                    // struct member with type of (possibly nested) array of matCx2's.
3349                    //
3350                    // Note that this only works for `Load`s and we handle
3351                    // `Store`s differently in `Statement::Store`.
3352                    if let Some(MatrixType {
3353                        columns,
3354                        rows: crate::VectorSize::Bi,
3355                        width,
3356                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3357                        .or_else(|| {
3358                            get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
3359                        })
3360                    {
3361                        write!(
3362                            self.out,
3363                            "__get_col_of_mat{}x2_f{}(",
3364                            columns as u8,
3365                            width * 8
3366                        )?;
3367                        self.write_expr(module, base, func_ctx)?;
3368                        write!(self.out, ", ")?;
3369                        self.write_expr(module, index, func_ctx)?;
3370                        write!(self.out, ")")?;
3371                        return Ok(());
3372                    }
3373
3374                    let resolved = func_ctx.resolve_type(base, &module.types);
3375
3376                    let (indexing_binding_array, non_uniform_qualifier) = match *resolved {
3377                        TypeInner::BindingArray { .. } => {
3378                            let uniformity = &func_ctx.info[index].uniformity;
3379
3380                            (true, uniformity.non_uniform_result.is_some())
3381                        }
3382                        _ => (false, false),
3383                    };
3384
3385                    self.write_expr(module, base, func_ctx)?;
3386
3387                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3388                        module, func_ctx, base, resolved,
3389                    );
3390
3391                    if let Some(ref info) = array_sampler_info {
3392                        write!(self.out, "{}[", info.sampler_heap_name)?;
3393                    } else {
3394                        write!(self.out, "[")?;
3395                    }
3396
3397                    let needs_bound_check = self.options.restrict_indexing
3398                        && !indexing_binding_array
3399                        && match resolved.pointer_space() {
3400                            Some(
3401                                crate::AddressSpace::Function
3402                                | crate::AddressSpace::Private
3403                                | crate::AddressSpace::WorkGroup
3404                                | crate::AddressSpace::Immediate
3405                                | crate::AddressSpace::TaskPayload
3406                                | crate::AddressSpace::RayPayload
3407                                | crate::AddressSpace::IncomingRayPayload,
3408                            )
3409                            | None => true,
3410                            Some(crate::AddressSpace::Uniform) => {
3411                                // check if BindTarget.restrict_indexing is set, this is used for dynamic buffers
3412                                let var_handle = self.fill_access_chain(module, base, func_ctx)?;
3413                                let bind_target = self
3414                                    .options
3415                                    .resolve_resource_binding(
3416                                        module.global_variables[var_handle]
3417                                            .binding
3418                                            .as_ref()
3419                                            .unwrap(),
3420                                    )
3421                                    .unwrap();
3422                                bind_target.restrict_indexing
3423                            }
3424                            Some(
3425                                crate::AddressSpace::Handle | crate::AddressSpace::Storage { .. },
3426                            ) => unreachable!(),
3427                        };
3428                    // Decide whether this index needs to be clamped to fall within range.
3429                    let restriction_needed = if needs_bound_check {
3430                        index::access_needs_check(
3431                            base,
3432                            index::GuardedIndex::Expression(index),
3433                            module,
3434                            func_ctx.expressions,
3435                            func_ctx.info,
3436                        )
3437                    } else {
3438                        None
3439                    };
3440                    if let Some(limit) = restriction_needed {
3441                        write!(self.out, "min(uint(")?;
3442                        self.write_expr(module, index, func_ctx)?;
3443                        write!(self.out, "), ")?;
3444                        match limit {
3445                            index::IndexableLength::Known(limit) => {
3446                                write!(self.out, "{}u", limit - 1)?;
3447                            }
3448                            index::IndexableLength::Dynamic => unreachable!(),
3449                        }
3450                        write!(self.out, ")")?;
3451                    } else {
3452                        if non_uniform_qualifier {
3453                            write!(self.out, "NonUniformResourceIndex(")?;
3454                        }
3455                        if let Some(ref info) = array_sampler_info {
3456                            write!(
3457                                self.out,
3458                                "{}[{} + ",
3459                                info.sampler_index_buffer_name, info.binding_array_base_index_name,
3460                            )?;
3461                        }
3462                        self.write_expr(module, index, func_ctx)?;
3463                        if array_sampler_info.is_some() {
3464                            write!(self.out, "]")?;
3465                        }
3466                        if non_uniform_qualifier {
3467                            write!(self.out, ")")?;
3468                        }
3469                    }
3470
3471                    write!(self.out, "]")?;
3472                }
3473            }
3474            Expression::AccessIndex { base, index } => {
3475                if let Some(crate::AddressSpace::Storage { .. }) =
3476                    func_ctx.resolve_type(expr, &module.types).pointer_space()
3477                {
3478                    // do nothing, the chain is written on `Load`/`Store`
3479                } else {
3480                    // See if we need to write the matrix column access in a
3481                    // special way since the type of `base` is our special
3482                    // __matCx2 struct.
3483                    if let Some(MatrixType {
3484                        rows: crate::VectorSize::Bi,
3485                        ..
3486                    }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
3487                        .or_else(|| {
3488                            get_inner_matrix_of_global_uniform(module, base, func_ctx, true)
3489                        })
3490                    {
3491                        self.write_expr(module, base, func_ctx)?;
3492                        write!(self.out, "._{index}")?;
3493                        return Ok(());
3494                    }
3495
3496                    let base_ty_res = &func_ctx.info[base].ty;
3497                    let mut resolved = base_ty_res.inner_with(&module.types);
3498                    let base_ty_handle = match *resolved {
3499                        TypeInner::Pointer { base, .. } => {
3500                            resolved = &module.types[base].inner;
3501                            Some(base)
3502                        }
3503                        _ => base_ty_res.handle(),
3504                    };
3505
3506                    // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
3507                    // See the module-level block comment in mod.rs for details.
3508                    //
3509                    // We handle matrix reconstruction here for Loads.
3510                    // Stores are handled directly by `Statement::Store`.
3511                    if let TypeInner::Struct { ref members, .. } = *resolved {
3512                        let member = &members[index as usize];
3513
3514                        match module.types[member.ty].inner {
3515                            TypeInner::Matrix {
3516                                rows: crate::VectorSize::Bi,
3517                                ..
3518                            } if member.binding.is_none() => {
3519                                let ty = base_ty_handle.unwrap();
3520                                self.write_wrapped_struct_matrix_get_function_name(
3521                                    WrappedStructMatrixAccess { ty, index },
3522                                )?;
3523                                write!(self.out, "(")?;
3524                                self.write_expr(module, base, func_ctx)?;
3525                                write!(self.out, ")")?;
3526                                return Ok(());
3527                            }
3528                            _ => {}
3529                        }
3530                    }
3531
3532                    let array_sampler_info = self.sampler_binding_array_info_from_expression(
3533                        module, func_ctx, base, resolved,
3534                    );
3535
3536                    if let Some(ref info) = array_sampler_info {
3537                        write!(
3538                            self.out,
3539                            "{}[{}",
3540                            info.sampler_heap_name, info.sampler_index_buffer_name
3541                        )?;
3542                    }
3543
3544                    self.write_expr(module, base, func_ctx)?;
3545
3546                    match *resolved {
3547                        // We specifically lift the ValuePointer to this case. While `[0]` is valid
3548                        // HLSL for any vector behind a value pointer, FXC completely miscompiles
3549                        // it and generates completely nonsensical DXBC.
3550                        //
3551                        // See https://github.com/gfx-rs/naga/issues/2095 for more details.
3552                        TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
3553                            // Write vector access as a swizzle
3554                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?
3555                        }
3556                        TypeInner::Matrix { .. }
3557                        | TypeInner::Array { .. }
3558                        | TypeInner::BindingArray { .. } => {
3559                            if let Some(ref info) = array_sampler_info {
3560                                write!(
3561                                    self.out,
3562                                    "[{} + {index}]",
3563                                    info.binding_array_base_index_name
3564                                )?;
3565                            } else {
3566                                write!(self.out, "[{index}]")?;
3567                            }
3568                        }
3569                        TypeInner::Struct { .. } => {
3570                            // The `unwrap` below cannot panic when the type is a `Struct`. That is not
3571                            // guaranteed for other types, so we can only unwrap inside this match arm.
3572                            let ty = base_ty_handle.unwrap();
3573
3574                            write!(
3575                                self.out,
3576                                ".{}",
3577                                &self.names[&NameKey::StructMember(ty, index)]
3578                            )?
3579                        }
3580                        ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
3581                    }
3582
3583                    if array_sampler_info.is_some() {
3584                        write!(self.out, "]")?;
3585                    }
3586                }
3587            }
3588            Expression::FunctionArgument(pos) => {
3589                let ty = func_ctx.resolve_type(expr, &module.types);
3590
3591                // We know that any external texture function argument has been expanded into
3592                // separate consecutive arguments for each plane and the parameters buffer. And we
3593                // also know that external textures can only ever be used as an argument to another
3594                // function. Therefore we can simply emit each of the expanded arguments in a
3595                // consecutive comma-separated list.
3596                if let TypeInner::Image {
3597                    class: crate::ImageClass::External,
3598                    ..
3599                } = *ty
3600                {
3601                    let plane_names = [0, 1, 2].map(|i| {
3602                        &self.names[&func_ctx
3603                            .external_texture_argument_key(pos, ExternalTextureNameKey::Plane(i))]
3604                    });
3605                    let params_name = &self.names[&func_ctx
3606                        .external_texture_argument_key(pos, ExternalTextureNameKey::Params)];
3607                    write!(
3608                        self.out,
3609                        "{}, {}, {}, {}",
3610                        plane_names[0], plane_names[1], plane_names[2], params_name
3611                    )?;
3612                } else {
3613                    let key = func_ctx.argument_key(pos);
3614                    let name = &self.names[&key];
3615                    write!(self.out, "{name}")?;
3616                }
3617            }
3618            Expression::ImageSample {
3619                coordinate,
3620                image,
3621                sampler,
3622                clamp_to_edge: true,
3623                gather: None,
3624                array_index: None,
3625                offset: None,
3626                level: crate::SampleLevel::Zero,
3627                depth_ref: None,
3628            } => {
3629                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
3630                self.write_expr(module, image, func_ctx)?;
3631                write!(self.out, ", ")?;
3632                self.write_expr(module, sampler, func_ctx)?;
3633                write!(self.out, ", ")?;
3634                self.write_expr(module, coordinate, func_ctx)?;
3635                write!(self.out, ")")?;
3636            }
3637            Expression::ImageSample {
3638                image,
3639                sampler,
3640                gather,
3641                coordinate,
3642                array_index,
3643                offset,
3644                level,
3645                depth_ref,
3646                clamp_to_edge,
3647            } => {
3648                if clamp_to_edge {
3649                    return Err(Error::Custom(
3650                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
3651                    ));
3652                }
3653
3654                use crate::SampleLevel as Sl;
3655                const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
3656
3657                let (base_str, component_str) = match gather {
3658                    Some(component) => ("Gather", COMPONENTS[component as usize]),
3659                    None => ("Sample", ""),
3660                };
3661                let cmp_str = match depth_ref {
3662                    Some(_) => "Cmp",
3663                    None => "",
3664                };
3665                let level_str = match level {
3666                    Sl::Zero if gather.is_none() => "LevelZero",
3667                    Sl::Auto | Sl::Zero => "",
3668                    Sl::Exact(_) => "Level",
3669                    Sl::Bias(_) => "Bias",
3670                    Sl::Gradient { .. } => "Grad",
3671                };
3672
3673                self.write_expr(module, image, func_ctx)?;
3674                write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
3675                self.write_expr(module, sampler, func_ctx)?;
3676                write!(self.out, ", ")?;
3677                self.write_texture_coordinates(
3678                    "float",
3679                    coordinate,
3680                    array_index,
3681                    None,
3682                    module,
3683                    func_ctx,
3684                )?;
3685
3686                if let Some(depth_ref) = depth_ref {
3687                    write!(self.out, ", ")?;
3688                    self.write_expr(module, depth_ref, func_ctx)?;
3689                }
3690
3691                match level {
3692                    Sl::Auto | Sl::Zero => {}
3693                    Sl::Exact(expr) => {
3694                        write!(self.out, ", ")?;
3695                        self.write_expr(module, expr, func_ctx)?;
3696                    }
3697                    Sl::Bias(expr) => {
3698                        write!(self.out, ", ")?;
3699                        self.write_expr(module, expr, func_ctx)?;
3700                    }
3701                    Sl::Gradient { x, y } => {
3702                        write!(self.out, ", ")?;
3703                        self.write_expr(module, x, func_ctx)?;
3704                        write!(self.out, ", ")?;
3705                        self.write_expr(module, y, func_ctx)?;
3706                    }
3707                }
3708
3709                if let Some(offset) = offset {
3710                    write!(self.out, ", ")?;
3711                    write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
3712                    self.write_const_expression(module, offset, func_ctx.expressions)?;
3713                    write!(self.out, ")")?;
3714                }
3715
3716                write!(self.out, ")")?;
3717            }
3718            Expression::ImageQuery { image, query } => {
3719                // use wrapped image query function
3720                if let TypeInner::Image {
3721                    dim,
3722                    arrayed,
3723                    class,
3724                } = *func_ctx.resolve_type(image, &module.types)
3725                {
3726                    let wrapped_image_query = WrappedImageQuery {
3727                        dim,
3728                        arrayed,
3729                        class,
3730                        query: query.into(),
3731                    };
3732
3733                    self.write_wrapped_image_query_function_name(wrapped_image_query)?;
3734                    write!(self.out, "(")?;
3735                    // Image always first param
3736                    self.write_expr(module, image, func_ctx)?;
3737                    if let crate::ImageQuery::Size { level: Some(level) } = query {
3738                        write!(self.out, ", ")?;
3739                        self.write_expr(module, level, func_ctx)?;
3740                    }
3741                    write!(self.out, ")")?;
3742                }
3743            }
3744            Expression::ImageLoad {
3745                image,
3746                coordinate,
3747                array_index,
3748                sample,
3749                level,
3750            } => self.write_image_load(
3751                &module,
3752                expr,
3753                func_ctx,
3754                image,
3755                coordinate,
3756                array_index,
3757                sample,
3758                level,
3759            )?,
3760            Expression::GlobalVariable(handle) => {
3761                let global_variable = &module.global_variables[handle];
3762                let ty = &module.types[global_variable.ty].inner;
3763
3764                // In the case of binding arrays of samplers, we need to not write anything
3765                // as we are in the wrong position to fully write the expression.
3766                //
3767                // The entire writing is done by AccessIndex.
3768                let is_binding_array_of_samplers = match *ty {
3769                    TypeInner::BindingArray { base, .. } => {
3770                        let base_ty = &module.types[base].inner;
3771                        matches!(*base_ty, TypeInner::Sampler { .. })
3772                    }
3773                    _ => false,
3774                };
3775
3776                let is_storage_space =
3777                    matches!(global_variable.space, crate::AddressSpace::Storage { .. });
3778
3779                // Our external texture global variable has been expanded into multiple
3780                // global variables, one for each plane and the parameters buffer.
3781                // External textures can only ever be used as arguments to a function
3782                // call, and we know that an external texture argument to any function
3783                // will have been expanded to separate consecutive arguments for each
3784                // plane and the parameters buffer. Therefore we can simply emit each of
3785                // the expanded global variables in a consecutive comma-separated list.
3786                if let TypeInner::Image {
3787                    class: crate::ImageClass::External,
3788                    ..
3789                } = *ty
3790                {
3791                    let plane_names = [0, 1, 2].map(|i| {
3792                        &self.names[&NameKey::ExternalTextureGlobalVariable(
3793                            handle,
3794                            ExternalTextureNameKey::Plane(i),
3795                        )]
3796                    });
3797                    let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
3798                        handle,
3799                        ExternalTextureNameKey::Params,
3800                    )];
3801                    write!(
3802                        self.out,
3803                        "{}, {}, {}, {}",
3804                        plane_names[0], plane_names[1], plane_names[2], params_name
3805                    )?;
3806                } else if !is_binding_array_of_samplers && !is_storage_space {
3807                    let name = &self.names[&NameKey::GlobalVariable(handle)];
3808                    write!(self.out, "{name}")?;
3809                }
3810            }
3811            Expression::LocalVariable(handle) => {
3812                write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
3813            }
3814            Expression::Load { pointer } => {
3815                match func_ctx
3816                    .resolve_type(pointer, &module.types)
3817                    .pointer_space()
3818                {
3819                    Some(crate::AddressSpace::Storage { .. }) => {
3820                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
3821                        let result_ty = func_ctx.info[expr].ty.clone();
3822                        self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
3823                    }
3824                    _ => {
3825                        let mut close_paren = false;
3826
3827                        // We cast the value loaded to a native HLSL floatCx2
3828                        // in cases where it is of type:
3829                        //  - __matCx2 or
3830                        //  - a (possibly nested) array of __matCx2's
3831                        if let Some(MatrixType {
3832                            rows: crate::VectorSize::Bi,
3833                            ..
3834                        }) = get_inner_matrix_of_struct_array_member(
3835                            module, pointer, func_ctx, false,
3836                        )
3837                        .or_else(|| {
3838                            get_inner_matrix_of_global_uniform(module, pointer, func_ctx, false)
3839                        }) {
3840                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3841                            let ptr_tr = resolved.pointer_base_type();
3842                            if let Some(ptr_ty) =
3843                                ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3844                            {
3845                                resolved = ptr_ty;
3846                            }
3847
3848                            write!(self.out, "((")?;
3849                            if let TypeInner::Array { base, size, .. } = *resolved {
3850                                self.write_type(module, base)?;
3851                                self.write_array_size(module, base, size)?;
3852                            } else {
3853                                self.write_value_type(module, resolved)?;
3854                            }
3855                            write!(self.out, ")")?;
3856                            close_paren = true;
3857                        }
3858
3859                        self.write_expr(module, pointer, func_ctx)?;
3860
3861                        if close_paren {
3862                            write!(self.out, ")")?;
3863                        }
3864                    }
3865                }
3866            }
3867            Expression::Unary { op, expr } => {
3868                // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
3869                let op_str = match op {
3870                    crate::UnaryOperator::Negate => {
3871                        match func_ctx.resolve_type(expr, &module.types).scalar() {
3872                            Some(Scalar::I32) => NEG_FUNCTION,
3873                            _ => "-",
3874                        }
3875                    }
3876                    crate::UnaryOperator::LogicalNot => "!",
3877                    crate::UnaryOperator::BitwiseNot => "~",
3878                };
3879                write!(self.out, "{op_str}(")?;
3880                self.write_expr(module, expr, func_ctx)?;
3881                write!(self.out, ")")?;
3882            }
3883            Expression::As {
3884                expr,
3885                kind,
3886                convert,
3887            } => {
3888                let inner = func_ctx.resolve_type(expr, &module.types);
3889                if inner.scalar_kind() == Some(ScalarKind::Float)
3890                    && (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3891                    && convert.is_some()
3892                    && matches!(convert, Some(4) | Some(8))
3893                {
3894                    // Use helper functions for float to int casts in order to
3895                    // avoid undefined behaviour when value is out of range for
3896                    // the target type.
3897                    let fun_name = match (kind, convert) {
3898                        (ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3899                        (ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3900                        (ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3901                        (ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3902                        _ => unreachable!(),
3903                    };
3904                    write!(self.out, "{fun_name}(")?;
3905                    self.write_expr(module, expr, func_ctx)?;
3906                    write!(self.out, ")")?;
3907                } else {
3908                    let close_paren = match convert {
3909                        Some(dst_width) => {
3910                            let scalar = Scalar {
3911                                kind,
3912                                width: dst_width,
3913                            };
3914                            match *inner {
3915                                TypeInner::Vector { size, .. } => {
3916                                    write!(
3917                                        self.out,
3918                                        "{}{}(",
3919                                        scalar.to_hlsl_str()?,
3920                                        common::vector_size_str(size)
3921                                    )?;
3922                                }
3923                                TypeInner::Scalar(_) => {
3924                                    write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3925                                }
3926                                TypeInner::Matrix { columns, rows, .. } => {
3927                                    write!(
3928                                        self.out,
3929                                        "{}{}x{}(",
3930                                        scalar.to_hlsl_str()?,
3931                                        common::vector_size_str(columns),
3932                                        common::vector_size_str(rows)
3933                                    )?;
3934                                }
3935                                _ => {
3936                                    return Err(Error::Unimplemented(format!(
3937                                        "write_expr expression::as {inner:?}"
3938                                    )));
3939                                }
3940                            };
3941                            true
3942                        }
3943                        None => {
3944                            if inner.scalar_width() == Some(8) {
3945                                false
3946                            } else if inner.scalar_width() == Some(2) {
3947                                // HLSL's asint()/asuint() only work on 32-bit types.
3948                                // For 16-bit bitcasts, use type constructor instead.
3949                                let dst_scalar = Scalar { kind, width: 2 };
3950                                match *inner {
3951                                    TypeInner::Vector { size, .. } => {
3952                                        write!(
3953                                            self.out,
3954                                            "{}{}(",
3955                                            dst_scalar.to_hlsl_str()?,
3956                                            common::vector_size_str(size)
3957                                        )?;
3958                                    }
3959                                    _ => {
3960                                        write!(self.out, "{}(", dst_scalar.to_hlsl_str()?)?;
3961                                    }
3962                                };
3963                                true
3964                            } else {
3965                                write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3966                                true
3967                            }
3968                        }
3969                    };
3970                    self.write_expr(module, expr, func_ctx)?;
3971                    if close_paren {
3972                        write!(self.out, ")")?;
3973                    }
3974                }
3975            }
3976            Expression::Math {
3977                fun,
3978                arg,
3979                arg1,
3980                arg2,
3981                arg3,
3982            } => {
3983                use crate::MathFunction as Mf;
3984
3985                enum Function { // How each `Mf` math function handled by this arm gets emitted.
3986                    Asincosh { is_sin: bool }, // `Mf::Asinh` (true) / `Mf::Acosh` (false): log(x + sqrt(x * x +/- 1.0))
3987                    Atanh, // written as 0.5 * log((1.0 + x) / (1.0 - x))
3988                    Pack2x16float, // f32tof16 of components [0] and [1], second shifted left by 16
3989                    Pack2x16snorm, // clamp to [-1, 1], scale by 32767, round, mask to 16 bits, pack
3990                    Pack2x16unorm, // scale factor is 65535
3991                    Pack4x8snorm, // this and the remaining pack/unpack variants map 1:1 from the `Mf` variant of the same name
3992                    Pack4x8unorm,
3993                    Pack4xI8,
3994                    Pack4xU8,
3995                    Pack4xI8Clamp,
3996                    Pack4xU8Clamp,
3997                    Unpack2x16float,
3998                    Unpack2x16snorm,
3999                    Unpack2x16unorm,
4000                    Unpack4x8snorm,
4001                    Unpack4x8unorm,
4002                    Unpack4xI8,
4003                    Unpack4xU8,
4004                    Dot4I8Packed, // selected for `Mf::Dot4I8Packed`
4005                    Dot4U8Packed, // selected for `Mf::Dot4U8Packed`
4006                    QuantizeToF16, // selected for `Mf::QuantizeToF16`
4007                    Regular(&'static str), // carries a function name: an HLSL intrinsic (e.g. "min") or a naga helper (e.g. `MODF_FUNCTION`)
4008                    MissingIntOverload(&'static str), // used for "countbits" / "reversebits"; NOTE(review): presumably the intrinsic lacks a signed-int overload - confirm
4009                    MissingIntReturnType(&'static str), // used for "firstbitlow" / "firstbithigh"; NOTE(review): presumably the result needs a cast back to int - confirm
4010                    CountTrailingZeros, // selected for `Mf::CountTrailingZeros`
4011                    CountLeadingZeros, // selected for `Mf::CountLeadingZeros`
4012                }
4013
4014                let fun = match fun {
4015                    // comparison
4016                    Mf::Abs => match func_ctx.resolve_type(arg, &module.types).scalar() {
4017                        Some(Scalar::I32) => Function::Regular(ABS_FUNCTION),
4018                        _ => Function::Regular("abs"),
4019                    },
4020                    Mf::Min => Function::Regular("min"),
4021                    Mf::Max => Function::Regular("max"),
4022                    Mf::Clamp => Function::Regular("clamp"),
4023                    Mf::Saturate => Function::Regular("saturate"),
4024                    // trigonometry
4025                    Mf::Cos => Function::Regular("cos"),
4026                    Mf::Cosh => Function::Regular("cosh"),
4027                    Mf::Sin => Function::Regular("sin"),
4028                    Mf::Sinh => Function::Regular("sinh"),
4029                    Mf::Tan => Function::Regular("tan"),
4030                    Mf::Tanh => Function::Regular("tanh"),
4031                    Mf::Acos => Function::Regular("acos"),
4032                    Mf::Asin => Function::Regular("asin"),
4033                    Mf::Atan => Function::Regular("atan"),
4034                    Mf::Atan2 => Function::Regular("atan2"),
4035                    Mf::Asinh => Function::Asincosh { is_sin: true },
4036                    Mf::Acosh => Function::Asincosh { is_sin: false },
4037                    Mf::Atanh => Function::Atanh,
4038                    Mf::Radians => Function::Regular("radians"),
4039                    Mf::Degrees => Function::Regular("degrees"),
4040                    // decomposition
4041                    Mf::Ceil => Function::Regular("ceil"),
4042                    Mf::Floor => Function::Regular("floor"),
4043                    Mf::Round => Function::Regular("round"),
4044                    Mf::Fract => Function::Regular("frac"),
4045                    Mf::Trunc => Function::Regular("trunc"),
4046                    Mf::Modf => Function::Regular(MODF_FUNCTION),
4047                    Mf::Frexp => Function::Regular(FREXP_FUNCTION),
4048                    Mf::Ldexp => Function::Regular("ldexp"),
4049                    // exponent
4050                    Mf::Exp => Function::Regular("exp"),
4051                    Mf::Exp2 => Function::Regular("exp2"),
4052                    Mf::Log => Function::Regular("log"),
4053                    Mf::Log2 => Function::Regular("log2"),
4054                    Mf::Pow => Function::Regular("pow"),
4055                    // geometry
4056                    Mf::Dot => Function::Regular("dot"),
4057                    Mf::Dot4I8Packed => Function::Dot4I8Packed,
4058                    Mf::Dot4U8Packed => Function::Dot4U8Packed,
4059                    //Mf::Outer => ,
4060                    Mf::Cross => Function::Regular("cross"),
4061                    Mf::Distance => Function::Regular("distance"),
4062                    Mf::Length => Function::Regular("length"),
4063                    Mf::Normalize => Function::Regular("normalize"),
4064                    Mf::FaceForward => Function::Regular("faceforward"),
4065                    Mf::Reflect => Function::Regular("reflect"),
4066                    Mf::Refract => Function::Regular("refract"),
4067                    // computational
4068                    Mf::Sign => Function::Regular("sign"),
4069                    Mf::Fma => Function::Regular("mad"),
4070                    Mf::Mix => Function::Regular("lerp"),
4071                    Mf::Step => Function::Regular("step"),
4072                    Mf::SmoothStep => Function::Regular("smoothstep"),
4073                    Mf::Sqrt => Function::Regular("sqrt"),
4074                    Mf::InverseSqrt => Function::Regular("rsqrt"),
4075                    //Mf::Inverse =>,
4076                    Mf::Transpose => Function::Regular("transpose"),
4077                    Mf::Determinant => Function::Regular("determinant"),
4078                    Mf::QuantizeToF16 => Function::QuantizeToF16,
4079                    // bits
4080                    Mf::CountTrailingZeros => Function::CountTrailingZeros,
4081                    Mf::CountLeadingZeros => Function::CountLeadingZeros,
4082                    Mf::CountOneBits => Function::MissingIntOverload("countbits"),
4083                    Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
4084                    Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
4085                    Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
4086                    Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
4087                    Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
4088                    // Data Packing
4089                    Mf::Pack2x16float => Function::Pack2x16float,
4090                    Mf::Pack2x16snorm => Function::Pack2x16snorm,
4091                    Mf::Pack2x16unorm => Function::Pack2x16unorm,
4092                    Mf::Pack4x8snorm => Function::Pack4x8snorm,
4093                    Mf::Pack4x8unorm => Function::Pack4x8unorm,
4094                    Mf::Pack4xI8 => Function::Pack4xI8,
4095                    Mf::Pack4xU8 => Function::Pack4xU8,
4096                    Mf::Pack4xI8Clamp => Function::Pack4xI8Clamp,
4097                    Mf::Pack4xU8Clamp => Function::Pack4xU8Clamp,
4098                    // Data Unpacking
4099                    Mf::Unpack2x16float => Function::Unpack2x16float,
4100                    Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
4101                    Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
4102                    Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
4103                    Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
4104                    Mf::Unpack4xI8 => Function::Unpack4xI8,
4105                    Mf::Unpack4xU8 => Function::Unpack4xU8,
4106                    _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
4107                };
4108
4109                match fun {
4110                    Function::Asincosh { is_sin } => {
4111                        write!(self.out, "log(")?;
4112                        self.write_expr(module, arg, func_ctx)?;
4113                        write!(self.out, " + sqrt(")?;
4114                        self.write_expr(module, arg, func_ctx)?;
4115                        write!(self.out, " * ")?;
4116                        self.write_expr(module, arg, func_ctx)?;
4117                        match is_sin {
4118                            true => write!(self.out, " + 1.0))")?,
4119                            false => write!(self.out, " - 1.0))")?,
4120                        }
4121                    }
4122                    Function::Atanh => {
4123                        write!(self.out, "0.5 * log((1.0 + ")?;
4124                        self.write_expr(module, arg, func_ctx)?;
4125                        write!(self.out, ") / (1.0 - ")?;
4126                        self.write_expr(module, arg, func_ctx)?;
4127                        write!(self.out, "))")?;
4128                    }
4129                    Function::Pack2x16float => {
4130                        write!(self.out, "(f32tof16(")?;
4131                        self.write_expr(module, arg, func_ctx)?;
4132                        write!(self.out, "[0]) | f32tof16(")?;
4133                        self.write_expr(module, arg, func_ctx)?;
4134                        write!(self.out, "[1]) << 16)")?;
4135                    }
4136                    Function::Pack2x16snorm => {
4137                        let scale = 32767;
4138
4139                        write!(self.out, "uint((int(round(clamp(")?;
4140                        self.write_expr(module, arg, func_ctx)?;
4141                        write!(
4142                            self.out,
4143                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
4144                        )?;
4145                        self.write_expr(module, arg, func_ctx)?;
4146                        write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
4147                    }
4148                    Function::Pack2x16unorm => {
4149                        let scale = 65535;
4150
4151                        write!(self.out, "(uint(round(clamp(")?;
4152                        self.write_expr(module, arg, func_ctx)?;
4153                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4154                        self.write_expr(module, arg, func_ctx)?;
4155                        write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
4156                    }
4157                    Function::Pack4x8snorm => {
4158                        let scale = 127;
4159
4160                        write!(self.out, "uint((int(round(clamp(")?;
4161                        self.write_expr(module, arg, func_ctx)?;
4162                        write!(
4163                            self.out,
4164                            "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
4165                        )?;
4166                        self.write_expr(module, arg, func_ctx)?;
4167                        write!(
4168                            self.out,
4169                            "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
4170                        )?;
4171                        self.write_expr(module, arg, func_ctx)?;
4172                        write!(
4173                            self.out,
4174                            "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
4175                        )?;
4176                        self.write_expr(module, arg, func_ctx)?;
4177                        write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
4178                    }
4179                    Function::Pack4x8unorm => {
4180                        let scale = 255;
4181
4182                        write!(self.out, "(uint(round(clamp(")?;
4183                        self.write_expr(module, arg, func_ctx)?;
4184                        write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
4185                        self.write_expr(module, arg, func_ctx)?;
4186                        write!(
4187                            self.out,
4188                            "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
4189                        )?;
4190                        self.write_expr(module, arg, func_ctx)?;
4191                        write!(
4192                            self.out,
4193                            "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
4194                        )?;
4195                        self.write_expr(module, arg, func_ctx)?;
4196                        write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
4197                    }
4198                    fun @ (Function::Pack4xI8
4199                    | Function::Pack4xU8
4200                    | Function::Pack4xI8Clamp
4201                    | Function::Pack4xU8Clamp) => {
4202                        let was_signed =
4203                            matches!(fun, Function::Pack4xI8 | Function::Pack4xI8Clamp);
4204                        let clamp_bounds = match fun {
4205                            Function::Pack4xI8Clamp => Some(("-128", "127")),
4206                            Function::Pack4xU8Clamp => Some(("0", "255")),
4207                            _ => None,
4208                        };
4209                        if was_signed {
4210                            write!(self.out, "uint(")?;
4211                        }
4212                        let write_arg = |this: &mut Self| -> BackendResult {
4213                            if let Some((min, max)) = clamp_bounds {
4214                                write!(this.out, "clamp(")?;
4215                                this.write_expr(module, arg, func_ctx)?;
4216                                write!(this.out, ", {min}, {max})")?;
4217                            } else {
4218                                this.write_expr(module, arg, func_ctx)?;
4219                            }
4220                            Ok(())
4221                        };
4222                        write!(self.out, "(")?;
4223                        write_arg(self)?;
4224                        write!(self.out, "[0] & 0xFF) | ((")?;
4225                        write_arg(self)?;
4226                        write!(self.out, "[1] & 0xFF) << 8) | ((")?;
4227                        write_arg(self)?;
4228                        write!(self.out, "[2] & 0xFF) << 16) | ((")?;
4229                        write_arg(self)?;
4230                        write!(self.out, "[3] & 0xFF) << 24)")?;
4231                        if was_signed {
4232                            write!(self.out, ")")?;
4233                        }
4234                    }
4235
4236                    Function::Unpack2x16float => {
4237                        write!(self.out, "float2(f16tof32(")?;
4238                        self.write_expr(module, arg, func_ctx)?;
4239                        write!(self.out, "), f16tof32((")?;
4240                        self.write_expr(module, arg, func_ctx)?;
4241                        write!(self.out, ") >> 16))")?;
4242                    }
4243                    Function::Unpack2x16snorm => {
4244                        let scale = 32767;
4245
4246                        write!(self.out, "(float2(int2(")?;
4247                        self.write_expr(module, arg, func_ctx)?;
4248                        write!(self.out, " << 16, ")?;
4249                        self.write_expr(module, arg, func_ctx)?;
4250                        write!(self.out, ") >> 16) / {scale}.0)")?;
4251                    }
4252                    Function::Unpack2x16unorm => {
4253                        let scale = 65535;
4254
4255                        write!(self.out, "(float2(")?;
4256                        self.write_expr(module, arg, func_ctx)?;
4257                        write!(self.out, " & 0xFFFF, ")?;
4258                        self.write_expr(module, arg, func_ctx)?;
4259                        write!(self.out, " >> 16) / {scale}.0)")?;
4260                    }
4261                    Function::Unpack4x8snorm => {
4262                        let scale = 127;
4263
4264                        write!(self.out, "(float4(int4(")?;
4265                        self.write_expr(module, arg, func_ctx)?;
4266                        write!(self.out, " << 24, ")?;
4267                        self.write_expr(module, arg, func_ctx)?;
4268                        write!(self.out, " << 16, ")?;
4269                        self.write_expr(module, arg, func_ctx)?;
4270                        write!(self.out, " << 8, ")?;
4271                        self.write_expr(module, arg, func_ctx)?;
4272                        write!(self.out, ") >> 24) / {scale}.0)")?;
4273                    }
4274                    Function::Unpack4x8unorm => {
4275                        let scale = 255;
4276
4277                        write!(self.out, "(float4(")?;
4278                        self.write_expr(module, arg, func_ctx)?;
4279                        write!(self.out, " & 0xFF, ")?;
4280                        self.write_expr(module, arg, func_ctx)?;
4281                        write!(self.out, " >> 8 & 0xFF, ")?;
4282                        self.write_expr(module, arg, func_ctx)?;
4283                        write!(self.out, " >> 16 & 0xFF, ")?;
4284                        self.write_expr(module, arg, func_ctx)?;
4285                        write!(self.out, " >> 24) / {scale}.0)")?;
4286                    }
4287                    fun @ (Function::Unpack4xI8 | Function::Unpack4xU8) => {
4288                        write!(self.out, "(")?;
4289                        if matches!(fun, Function::Unpack4xU8) {
4290                            write!(self.out, "u")?;
4291                        }
4292                        write!(self.out, "int4(")?;
4293                        self.write_expr(module, arg, func_ctx)?;
4294                        write!(self.out, ", ")?;
4295                        self.write_expr(module, arg, func_ctx)?;
4296                        write!(self.out, " >> 8, ")?;
4297                        self.write_expr(module, arg, func_ctx)?;
4298                        write!(self.out, " >> 16, ")?;
4299                        self.write_expr(module, arg, func_ctx)?;
4300                        write!(self.out, " >> 24) << 24 >> 24)")?;
4301                    }
4302                    fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
4303                        let arg1 = arg1.unwrap();
4304
4305                        if self.options.shader_model >= ShaderModel::V6_4 {
4306                            // Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
4307                            let function_name = match fun {
4308                                Function::Dot4I8Packed => "dot4add_i8packed",
4309                                Function::Dot4U8Packed => "dot4add_u8packed",
4310                                _ => unreachable!(),
4311                            };
4312                            write!(self.out, "{function_name}(")?;
4313                            self.write_expr(module, arg, func_ctx)?;
4314                            write!(self.out, ", ")?;
4315                            self.write_expr(module, arg1, func_ctx)?;
4316                            write!(self.out, ", 0)")?;
4317                        } else {
4318                            // Fall back to a polyfill as `dot4add_u8packed` is not available.
4319                            write!(self.out, "dot(")?;
4320
4321                            if matches!(fun, Function::Dot4U8Packed) {
4322                                write!(self.out, "u")?;
4323                            }
4324                            write!(self.out, "int4(")?;
4325                            self.write_expr(module, arg, func_ctx)?;
4326                            write!(self.out, ", ")?;
4327                            self.write_expr(module, arg, func_ctx)?;
4328                            write!(self.out, " >> 8, ")?;
4329                            self.write_expr(module, arg, func_ctx)?;
4330                            write!(self.out, " >> 16, ")?;
4331                            self.write_expr(module, arg, func_ctx)?;
4332                            write!(self.out, " >> 24) << 24 >> 24, ")?;
4333
4334                            if matches!(fun, Function::Dot4U8Packed) {
4335                                write!(self.out, "u")?;
4336                            }
4337                            write!(self.out, "int4(")?;
4338                            self.write_expr(module, arg1, func_ctx)?;
4339                            write!(self.out, ", ")?;
4340                            self.write_expr(module, arg1, func_ctx)?;
4341                            write!(self.out, " >> 8, ")?;
4342                            self.write_expr(module, arg1, func_ctx)?;
4343                            write!(self.out, " >> 16, ")?;
4344                            self.write_expr(module, arg1, func_ctx)?;
4345                            write!(self.out, " >> 24) << 24 >> 24)")?;
4346                        }
4347                    }
4348                    Function::QuantizeToF16 => {
4349                        write!(self.out, "f16tof32(f32tof16(")?;
4350                        self.write_expr(module, arg, func_ctx)?;
4351                        write!(self.out, "))")?;
4352                    }
4353                    Function::Regular(fun_name) => {
4354                        write!(self.out, "{fun_name}(")?;
4355                        self.write_expr(module, arg, func_ctx)?;
4356                        if let Some(arg) = arg1 {
4357                            write!(self.out, ", ")?;
4358                            self.write_expr(module, arg, func_ctx)?;
4359                        }
4360                        if let Some(arg) = arg2 {
4361                            write!(self.out, ", ")?;
4362                            self.write_expr(module, arg, func_ctx)?;
4363                        }
4364                        if let Some(arg) = arg3 {
4365                            write!(self.out, ", ")?;
4366                            self.write_expr(module, arg, func_ctx)?;
4367                        }
4368                        write!(self.out, ")")?
4369                    }
4370                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4371                    // as non-32bit types are DXC only.
4372                    Function::MissingIntOverload(fun_name) => {
4373                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4374                        if let Some(Scalar::I32) = scalar_kind {
4375                            write!(self.out, "asint({fun_name}(asuint(")?;
4376                            self.write_expr(module, arg, func_ctx)?;
4377                            write!(self.out, ")))")?;
4378                        } else {
4379                            write!(self.out, "{fun_name}(")?;
4380                            self.write_expr(module, arg, func_ctx)?;
4381                            write!(self.out, ")")?;
4382                        }
4383                    }
4384                    // These overloads are only missing on FXC, so this is only needed for 32bit types,
4385                    // as non-32bit types are DXC only.
4386                    Function::MissingIntReturnType(fun_name) => {
4387                        let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
4388                        if let Some(Scalar::I32) = scalar_kind {
4389                            write!(self.out, "asint({fun_name}(")?;
4390                            self.write_expr(module, arg, func_ctx)?;
4391                            write!(self.out, "))")?;
4392                        } else {
4393                            write!(self.out, "{fun_name}(")?;
4394                            self.write_expr(module, arg, func_ctx)?;
4395                            write!(self.out, ")")?;
4396                        }
4397                    }
4398                    Function::CountTrailingZeros => {
4399                        match *func_ctx.resolve_type(arg, &module.types) {
4400                            TypeInner::Vector { size, scalar } => {
4401                                let s = match size {
4402                                    crate::VectorSize::Bi => ".xx",
4403                                    crate::VectorSize::Tri => ".xxx",
4404                                    crate::VectorSize::Quad => ".xxxx",
4405                                };
4406
4407                                let scalar_width_bits = scalar.width * 8;
4408
4409                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4410                                    write!(
4411                                        self.out,
4412                                        "min(({scalar_width_bits}u){s}, firstbitlow("
4413                                    )?;
4414                                    self.write_expr(module, arg, func_ctx)?;
4415                                    write!(self.out, "))")?;
4416                                } else {
4417                                    // This is only needed for the FXC path, on 32bit signed integers.
4418                                    write!(
4419                                        self.out,
4420                                        "asint(min(({scalar_width_bits}u){s}, firstbitlow("
4421                                    )?;
4422                                    self.write_expr(module, arg, func_ctx)?;
4423                                    write!(self.out, ")))")?;
4424                                }
4425                            }
4426                            TypeInner::Scalar(scalar) => {
4427                                let scalar_width_bits = scalar.width * 8;
4428
4429                                if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
4430                                    write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
4431                                    self.write_expr(module, arg, func_ctx)?;
4432                                    write!(self.out, "))")?;
4433                                } else {
4434                                    // This is only needed for the FXC path, on 32bit signed integers.
4435                                    write!(
4436                                        self.out,
4437                                        "asint(min({scalar_width_bits}u, firstbitlow("
4438                                    )?;
4439                                    self.write_expr(module, arg, func_ctx)?;
4440                                    write!(self.out, ")))")?;
4441                                }
4442                            }
4443                            _ => unreachable!(),
4444                        }
4445
4446                        return Ok(());
4447                    }
4448                    Function::CountLeadingZeros => {
4449                        match *func_ctx.resolve_type(arg, &module.types) {
4450                            TypeInner::Vector { size, scalar } => {
4451                                let s = match size {
4452                                    crate::VectorSize::Bi => ".xx",
4453                                    crate::VectorSize::Tri => ".xxx",
4454                                    crate::VectorSize::Quad => ".xxxx",
4455                                };
4456
4457                                // scalar width - 1
4458                                let constant = scalar.width * 8 - 1;
4459
4460                                if scalar.kind == ScalarKind::Uint {
4461                                    write!(self.out, "(({constant}u){s} - firstbithigh(")?;
4462                                    self.write_expr(module, arg, func_ctx)?;
4463                                    write!(self.out, "))")?;
4464                                } else {
4465                                    let conversion_func = match scalar.width {
4466                                        4 => "asint",
4467                                        _ => "",
4468                                    };
4469                                    write!(self.out, "(")?;
4470                                    self.write_expr(module, arg, func_ctx)?;
4471                                    write!(
4472                                        self.out,
4473                                        " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
4474                                    )?;
4475                                    self.write_expr(module, arg, func_ctx)?;
4476                                    write!(self.out, ")))")?;
4477                                }
4478                            }
4479                            TypeInner::Scalar(scalar) => {
4480                                // scalar width - 1
4481                                let constant = scalar.width * 8 - 1;
4482
4483                                if let ScalarKind::Uint = scalar.kind {
4484                                    write!(self.out, "({constant}u - firstbithigh(")?;
4485                                    self.write_expr(module, arg, func_ctx)?;
4486                                    write!(self.out, "))")?;
4487                                } else {
4488                                    let conversion_func = match scalar.width {
4489                                        4 => "asint",
4490                                        _ => "",
4491                                    };
4492                                    write!(self.out, "(")?;
4493                                    self.write_expr(module, arg, func_ctx)?;
4494                                    write!(
4495                                        self.out,
4496                                        " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
4497                                    )?;
4498                                    self.write_expr(module, arg, func_ctx)?;
4499                                    write!(self.out, ")))")?;
4500                                }
4501                            }
4502                            _ => unreachable!(),
4503                        }
4504
4505                        return Ok(());
4506                    }
4507                }
4508            }
4509            Expression::Swizzle {
4510                size,
4511                vector,
4512                pattern,
4513            } => {
4514                self.write_expr(module, vector, func_ctx)?;
4515                write!(self.out, ".")?;
4516                for &sc in pattern[..size as usize].iter() {
4517                    self.out.write_char(back::COMPONENTS[sc as usize])?;
4518                }
4519            }
4520            Expression::ArrayLength(expr) => {
4521                let var_handle = match func_ctx.expressions[expr] {
4522                    Expression::AccessIndex { base, index: _ } => {
4523                        match func_ctx.expressions[base] {
4524                            Expression::GlobalVariable(handle) => handle,
4525                            _ => unreachable!(),
4526                        }
4527                    }
4528                    Expression::GlobalVariable(handle) => handle,
4529                    _ => unreachable!(),
4530                };
4531
4532                let var = &module.global_variables[var_handle];
4533                let (offset, stride) = match module.types[var.ty].inner {
4534                    TypeInner::Array { stride, .. } => (0, stride),
4535                    TypeInner::Struct { ref members, .. } => {
4536                        let last = members.last().unwrap();
4537                        let stride = match module.types[last.ty].inner {
4538                            TypeInner::Array { stride, .. } => stride,
4539                            _ => unreachable!(),
4540                        };
4541                        (last.offset, stride)
4542                    }
4543                    _ => unreachable!(),
4544                };
4545
4546                let storage_access = match var.space {
4547                    crate::AddressSpace::Storage { access } => access,
4548                    _ => crate::StorageAccess::default(),
4549                };
4550                let wrapped_array_length = WrappedArrayLength {
4551                    writable: storage_access.contains(crate::StorageAccess::STORE),
4552                };
4553
4554                write!(self.out, "((")?;
4555                self.write_wrapped_array_length_function_name(wrapped_array_length)?;
4556                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4557                write!(self.out, "({var_name}) - {offset}) / {stride})")?
4558            }
4559            Expression::Derivative { axis, ctrl, expr } => {
4560                use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
4561                if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
4562                    let tail = match ctrl {
4563                        Ctrl::Coarse => "coarse",
4564                        Ctrl::Fine => "fine",
4565                        Ctrl::None => unreachable!(),
4566                    };
4567                    write!(self.out, "abs(ddx_{tail}(")?;
4568                    self.write_expr(module, expr, func_ctx)?;
4569                    write!(self.out, ")) + abs(ddy_{tail}(")?;
4570                    self.write_expr(module, expr, func_ctx)?;
4571                    write!(self.out, "))")?
4572                } else {
4573                    let fun_str = match (axis, ctrl) {
4574                        (Axis::X, Ctrl::Coarse) => "ddx_coarse",
4575                        (Axis::X, Ctrl::Fine) => "ddx_fine",
4576                        (Axis::X, Ctrl::None) => "ddx",
4577                        (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
4578                        (Axis::Y, Ctrl::Fine) => "ddy_fine",
4579                        (Axis::Y, Ctrl::None) => "ddy",
4580                        (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
4581                        (Axis::Width, Ctrl::None) => "fwidth",
4582                    };
4583                    write!(self.out, "{fun_str}(")?;
4584                    self.write_expr(module, expr, func_ctx)?;
4585                    write!(self.out, ")")?
4586                }
4587            }
4588            Expression::Relational { fun, argument } => {
4589                use crate::RelationalFunction as Rf;
4590
4591                let fun_str = match fun {
4592                    Rf::All => "all",
4593                    Rf::Any => "any",
4594                    Rf::IsNan => "isnan",
4595                    Rf::IsInf => "isinf",
4596                };
4597                write!(self.out, "{fun_str}(")?;
4598                self.write_expr(module, argument, func_ctx)?;
4599                write!(self.out, ")")?
4600            }
4601            Expression::Select {
4602                condition,
4603                accept,
4604                reject,
4605            } => {
4606                write!(self.out, "(")?;
4607                self.write_expr(module, condition, func_ctx)?;
4608                write!(self.out, " ? ")?;
4609                self.write_expr(module, accept, func_ctx)?;
4610                write!(self.out, " : ")?;
4611                self.write_expr(module, reject, func_ctx)?;
4612                write!(self.out, ")")?
4613            }
4614            Expression::RayQueryGetIntersection { query, committed } => {
4615                // For reasoning, see write_stmt
4616                let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else {
4617                    unreachable!()
4618                };
4619
4620                let tracker_expr_name = format!(
4621                    "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}",
4622                    self.names[&func_ctx.name_key(query_var)]
4623                );
4624
4625                if committed {
4626                    write!(self.out, "GetCommittedIntersection(")?;
4627                    self.write_expr(module, query, func_ctx)?;
4628                    write!(self.out, ", {tracker_expr_name})")?;
4629                } else {
4630                    write!(self.out, "GetCandidateIntersection(")?;
4631                    self.write_expr(module, query, func_ctx)?;
4632                    write!(self.out, ", {tracker_expr_name})")?;
4633                }
4634            }
4635            // Not supported yet
4636            Expression::RayQueryVertexPositions { .. }
4637            | Expression::CooperativeLoad { .. }
4638            | Expression::CooperativeMultiplyAdd { .. } => {
4639                unreachable!()
4640            }
4641            // Nothing to do here, since call expression already cached
4642            Expression::CallResult(_)
4643            | Expression::AtomicResult { .. }
4644            | Expression::WorkGroupUniformLoadResult { .. }
4645            | Expression::RayQueryProceedResult
4646            | Expression::SubgroupBallotResult
4647            | Expression::SubgroupOperationResult { .. } => {}
4648        }
4649
4650        if !closing_bracket.is_empty() {
4651            write!(self.out, "{closing_bracket}")?;
4652        }
4653        Ok(())
4654    }
4655
4656    #[allow(clippy::too_many_arguments)]
4657    fn write_image_load(
4658        &mut self,
4659        module: &&Module,
4660        expr: Handle<crate::Expression>,
4661        func_ctx: &back::FunctionCtx,
4662        image: Handle<crate::Expression>,
4663        coordinate: Handle<crate::Expression>,
4664        array_index: Option<Handle<crate::Expression>>,
4665        sample: Option<Handle<crate::Expression>>,
4666        level: Option<Handle<crate::Expression>>,
4667    ) -> Result<(), Error> {
4668        let mut wrapping_type = None;
4669        match *func_ctx.resolve_type(image, &module.types) {
4670            TypeInner::Image {
4671                class: crate::ImageClass::External,
4672                ..
4673            } => {
4674                write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
4675                self.write_expr(module, image, func_ctx)?;
4676                write!(self.out, ", ")?;
4677                self.write_expr(module, coordinate, func_ctx)?;
4678                write!(self.out, ")")?;
4679                return Ok(());
4680            }
4681            TypeInner::Image {
4682                class: crate::ImageClass::Storage { format, .. },
4683                ..
4684            } => {
4685                if format.single_component() {
4686                    wrapping_type = Some(Scalar::from(format));
4687                }
4688            }
4689            _ => {}
4690        }
4691        if let Some(scalar) = wrapping_type {
4692            write!(
4693                self.out,
4694                "{}{}(",
4695                help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
4696                scalar.to_hlsl_str()?
4697            )?;
4698        }
4699        // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
4700        self.write_expr(module, image, func_ctx)?;
4701        write!(self.out, ".Load(")?;
4702
4703        self.write_texture_coordinates("int", coordinate, array_index, level, module, func_ctx)?;
4704
4705        if let Some(sample) = sample {
4706            write!(self.out, ", ")?;
4707            self.write_expr(module, sample, func_ctx)?;
4708        }
4709
4710        // close bracket for Load function
4711        write!(self.out, ")")?;
4712
4713        if wrapping_type.is_some() {
4714            write!(self.out, ")")?;
4715        }
4716
4717        // return x component if return type is scalar
4718        if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
4719            write!(self.out, ".x")?;
4720        }
4721        Ok(())
4722    }
4723
4724    /// Find the [`BindingArraySamplerInfo`] from an expression so that such an access
4725    /// can be generated later.
4726    fn sampler_binding_array_info_from_expression(
4727        &mut self,
4728        module: &Module,
4729        func_ctx: &back::FunctionCtx<'_>,
4730        base: Handle<crate::Expression>,
4731        resolved: &TypeInner,
4732    ) -> Option<BindingArraySamplerInfo> {
4733        if let TypeInner::BindingArray {
4734            base: base_ty_handle,
4735            ..
4736        } = *resolved
4737        {
4738            let base_ty = &module.types[base_ty_handle].inner;
4739            if let TypeInner::Sampler { comparison, .. } = *base_ty {
4740                let base = &func_ctx.expressions[base];
4741
4742                if let crate::Expression::GlobalVariable(handle) = *base {
4743                    let variable = &module.global_variables[handle];
4744
4745                    let sampler_heap_name = match comparison {
4746                        true => COMPARISON_SAMPLER_HEAP_VAR,
4747                        false => SAMPLER_HEAP_VAR,
4748                    };
4749
4750                    return Some(BindingArraySamplerInfo {
4751                        sampler_heap_name,
4752                        sampler_index_buffer_name: self
4753                            .wrapped
4754                            .sampler_index_buffers
4755                            .get(&super::SamplerIndexBufferKey {
4756                                group: variable.binding.unwrap().group,
4757                            })
4758                            .unwrap()
4759                            .clone(),
4760                        binding_array_base_index_name: self.names[&NameKey::GlobalVariable(handle)]
4761                            .clone(),
4762                    });
4763                }
4764            }
4765        }
4766
4767        None
4768    }
4769
4770    fn write_named_expr(
4771        &mut self,
4772        module: &Module,
4773        handle: Handle<crate::Expression>,
4774        name: String,
4775        // The expression which is being named.
4776        // Generally, this is the same as handle, except in WorkGroupUniformLoad
4777        expr: Handle<crate::Expression>,
4778        func_ctx: &back::FunctionCtx,
4779    ) -> BackendResult {
4780        if let crate::Expression::Load { pointer } = func_ctx.expressions[expr] {
4781            let ty_inner = func_ctx.resolve_type(pointer, &module.types);
4782            if ty_inner.is_atomic_pointer(&module.types) {
4783                let pointer_space = ty_inner.pointer_space().unwrap();
4784                self.write_value_type(module, func_ctx.info[handle].ty.inner_with(&module.types))?;
4785                write!(self.out, " {name}; ")?;
4786                match pointer_space {
4787                    crate::AddressSpace::WorkGroup => {
4788                        write!(self.out, "InterlockedOr(")?;
4789                        self.write_expr(module, pointer, func_ctx)?;
4790                    }
4791                    crate::AddressSpace::Storage { .. } => {
4792                        let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
4793                        let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
4794                        write!(self.out, "{var_name}.InterlockedOr(")?;
4795                        let chain = mem::take(&mut self.temp_access_chain);
4796                        self.write_storage_address(module, &chain, func_ctx)?;
4797                        self.temp_access_chain = chain;
4798                    }
4799                    _ => unreachable!(),
4800                }
4801                writeln!(self.out, ", 0, {name});")?;
4802                self.named_expressions.insert(expr, name);
4803                return Ok(());
4804            }
4805        }
4806        match func_ctx.info[expr].ty {
4807            proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
4808                TypeInner::Struct { .. } => {
4809                    let ty_name = &self.names[&NameKey::Type(ty_handle)];
4810                    write!(self.out, "{ty_name}")?;
4811                }
4812                _ => {
4813                    self.write_type(module, ty_handle)?;
4814                }
4815            },
4816            proc::TypeResolution::Value(ref inner) => {
4817                self.write_value_type(module, inner)?;
4818            }
4819        }
4820
4821        let resolved = func_ctx.resolve_type(expr, &module.types);
4822
4823        write!(self.out, " {name}")?;
4824        // If rhs is a array type, we should write array size
4825        if let TypeInner::Array { base, size, .. } = *resolved {
4826            self.write_array_size(module, base, size)?;
4827        }
4828        write!(self.out, " = ")?;
4829        self.write_expr(module, handle, func_ctx)?;
4830        writeln!(self.out, ";")?;
4831        self.named_expressions.insert(expr, name);
4832
4833        Ok(())
4834    }
4835
4836    /// Helper function that write default zero initialization
4837    pub(super) fn write_default_init(
4838        &mut self,
4839        module: &Module,
4840        ty: Handle<crate::Type>,
4841    ) -> BackendResult {
4842        write!(self.out, "(")?;
4843        self.write_type(module, ty)?;
4844        if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
4845            self.write_array_size(module, base, size)?;
4846        }
4847        write!(self.out, ")0")?;
4848        Ok(())
4849    }
4850
4851    pub(super) fn write_control_barrier(
4852        &mut self,
4853        barrier: crate::Barrier,
4854        level: back::Level,
4855    ) -> BackendResult {
4856        if barrier.contains(crate::Barrier::STORAGE) {
4857            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4858        }
4859        if barrier.contains(crate::Barrier::WORK_GROUP) {
4860            writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
4861        }
4862        if barrier.contains(crate::Barrier::SUB_GROUP) {
4863            // Does not exist in DirectX
4864        }
4865        if barrier.contains(crate::Barrier::TEXTURE) {
4866            writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
4867        }
4868        Ok(())
4869    }
4870
4871    fn write_memory_barrier(
4872        &mut self,
4873        barrier: crate::Barrier,
4874        level: back::Level,
4875    ) -> BackendResult {
4876        if barrier.contains(crate::Barrier::STORAGE) {
4877            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4878        }
4879        if barrier.contains(crate::Barrier::WORK_GROUP) {
4880            writeln!(self.out, "{level}GroupMemoryBarrier();")?;
4881        }
4882        if barrier.contains(crate::Barrier::SUB_GROUP) {
4883            // Does not exist in DirectX
4884        }
4885        if barrier.contains(crate::Barrier::TEXTURE) {
4886            writeln!(self.out, "{level}DeviceMemoryBarrier();")?;
4887        }
4888        Ok(())
4889    }
4890
4891    /// Helper to emit the shared tail of an HLSL atomic call (arguments, value, result)
4892    fn emit_hlsl_atomic_tail(
4893        &mut self,
4894        module: &Module,
4895        func_ctx: &back::FunctionCtx<'_>,
4896        fun: &crate::AtomicFunction,
4897        compare_expr: Option<Handle<crate::Expression>>,
4898        value: Handle<crate::Expression>,
4899        res_var_info: &Option<(Handle<crate::Expression>, String)>,
4900    ) -> BackendResult {
4901        if let Some(cmp) = compare_expr {
4902            write!(self.out, ", ")?;
4903            self.write_expr(module, cmp, func_ctx)?;
4904        }
4905        write!(self.out, ", ")?;
4906        if let crate::AtomicFunction::Subtract = *fun {
4907            // we just wrote `InterlockedAdd`, so negate the argument
4908            write!(self.out, "-")?;
4909        }
4910        self.write_expr(module, value, func_ctx)?;
4911        if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4912            write!(self.out, ", ")?;
4913            if compare_expr.is_some() {
4914                write!(self.out, "{res_name}.old_value")?;
4915            } else {
4916                write!(self.out, "{res_name}")?;
4917            }
4918        }
4919        writeln!(self.out, ");")?;
4920        Ok(())
4921    }
4922}
4923
/// Shape and scalar width of a matrix type, as extracted from a
/// [`TypeInner::Matrix`] by the helpers in this file.
pub(super) struct MatrixType {
    /// Number of columns of the matrix.
    pub(super) columns: crate::VectorSize,
    /// Number of rows of the matrix.
    pub(super) rows: crate::VectorSize,
    /// Width of the matrix's scalar component type (taken from `Scalar::width`).
    pub(super) width: crate::Bytes,
}
4929
4930pub(super) fn get_inner_matrix_data(
4931    module: &Module,
4932    handle: Handle<crate::Type>,
4933) -> Option<MatrixType> {
4934    match module.types[handle].inner {
4935        TypeInner::Matrix {
4936            columns,
4937            rows,
4938            scalar,
4939        } => Some(MatrixType {
4940            columns,
4941            rows,
4942            width: scalar.width,
4943        }),
4944        TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
4945        _ => None,
4946    }
4947}
4948
4949/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
4950/// returns a tuple of the matrix, the column (vector) index (if present), and
4951/// the row (scalar) index (if present).
4952fn find_matrix_in_access_chain(
4953    module: &Module,
4954    base: Handle<crate::Expression>,
4955    func_ctx: &back::FunctionCtx<'_>,
4956) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
4957    let mut current_base = base;
4958    let mut vector = None;
4959    let mut scalar = None;
4960    loop {
4961        let resolved_tr = func_ctx
4962            .resolve_type(current_base, &module.types)
4963            .pointer_base_type();
4964        let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
4965
4966        match *resolved {
4967            TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
4968            TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4969            _ => return None,
4970        }
4971
4972        let index;
4973        (current_base, index) = match func_ctx.expressions[current_base] {
4974            crate::Expression::Access { base, index } => (base, Index::Expression(index)),
4975            crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
4976            _ => return None,
4977        };
4978
4979        match *resolved {
4980            TypeInner::Scalar(_) => scalar = Some(index),
4981            TypeInner::Vector { .. } => vector = Some(index),
4982            _ => unreachable!(),
4983        }
4984    }
4985}
4986
4987/// Returns the matrix data if the access chain starting at `base`:
4988/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
4989/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
4990/// - ends at an expression with resolved type of [`TypeInner::Struct`]
4991pub(super) fn get_inner_matrix_of_struct_array_member(
4992    module: &Module,
4993    base: Handle<crate::Expression>,
4994    func_ctx: &back::FunctionCtx<'_>,
4995    direct: bool,
4996) -> Option<MatrixType> {
4997    let mut mat_data = None;
4998    let mut array_base = None;
4999
5000    let mut current_base = base;
5001    loop {
5002        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5003        if let TypeInner::Pointer { base, .. } = *resolved {
5004            resolved = &module.types[base].inner;
5005        };
5006
5007        match *resolved {
5008            TypeInner::Matrix {
5009                columns,
5010                rows,
5011                scalar,
5012            } => {
5013                mat_data = Some(MatrixType {
5014                    columns,
5015                    rows,
5016                    width: scalar.width,
5017                })
5018            }
5019            TypeInner::Array { base, .. } => {
5020                array_base = Some(base);
5021            }
5022            TypeInner::Struct { .. } => {
5023                if let Some(array_base) = array_base {
5024                    if direct {
5025                        return mat_data;
5026                    } else {
5027                        return get_inner_matrix_data(module, array_base);
5028                    }
5029                }
5030
5031                break;
5032            }
5033            _ => break,
5034        }
5035
5036        current_base = match func_ctx.expressions[current_base] {
5037            crate::Expression::Access { base, .. } => base,
5038            crate::Expression::AccessIndex { base, .. } => base,
5039            _ => break,
5040        };
5041    }
5042    None
5043}
5044
5045/// Returns the matrix data if the access chain starting at `base`:
5046/// - starts with an expression with resolved type of [`TypeInner::Matrix`], or
5047/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] if `direct = false`
5048/// - and ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
5049fn get_inner_matrix_of_global_uniform(
5050    module: &Module,
5051    base: Handle<crate::Expression>,
5052    func_ctx: &back::FunctionCtx<'_>,
5053    direct: bool,
5054) -> Option<MatrixType> {
5055    let mut mat_data = None;
5056    let mut array_base = None;
5057
5058    let mut current_base = base;
5059    loop {
5060        let mut resolved = func_ctx.resolve_type(current_base, &module.types);
5061        if let TypeInner::Pointer { base, .. } = *resolved {
5062            resolved = &module.types[base].inner;
5063        };
5064
5065        match *resolved {
5066            TypeInner::Matrix {
5067                columns,
5068                rows,
5069                scalar,
5070            } => {
5071                mat_data = Some(MatrixType {
5072                    columns,
5073                    rows,
5074                    width: scalar.width,
5075                })
5076            }
5077            TypeInner::Array { base, .. } => {
5078                if !direct {
5079                    array_base = Some(base);
5080                }
5081            }
5082            _ => break,
5083        }
5084
5085        current_base = match func_ctx.expressions[current_base] {
5086            crate::Expression::Access { base, .. } => base,
5087            crate::Expression::AccessIndex { base, .. } => base,
5088            crate::Expression::GlobalVariable(handle)
5089                if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
5090            {
5091                return mat_data.or_else(|| {
5092                    array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
5093                })
5094            }
5095            _ => break,
5096        };
5097    }
5098    None
5099}