naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17    ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18    TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21    arena::{Handle, HandleSet},
22    back::{
23        self, get_entry_points,
24        msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25        Baked,
26    },
27    common,
28    proc::{
29        self, concrete_int_scalars,
30        index::{self, BoundsCheck},
31        ExternalTextureNameKey, NameKey, TypeResolution,
32    },
33    valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// A more general way of handling pointers still needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

// Identifier strings reserved by this backend for its own helpers.
// NOTE(review): the code that emits and calls these helpers is outside this
// chunk; the names below are presumably MSL function names — confirm against
// the code that uses each constant.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you will have noticed that this should look exactly the same to the compiler, and
/// you're correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
75
76/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
77///
78/// The `sizes` slice determines whether this function writes a
79/// scalar, vector, or matrix type:
80///
81/// - An empty slice produces a scalar type.
82/// - A one-element slice produces a vector type.
83/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
84fn put_numeric_type(
85    out: &mut impl Write,
86    scalar: crate::Scalar,
87    sizes: &[crate::VectorSize],
88) -> Result<(), FmtError> {
89    match (scalar, sizes) {
90        (scalar, &[]) => {
91            write!(out, "{}", scalar.to_msl_name())
92        }
93        (scalar, &[rows]) => {
94            write!(
95                out,
96                "{}::{}{}",
97                NAMESPACE,
98                scalar.to_msl_name(),
99                common::vector_size_str(rows)
100            )
101        }
102        (scalar, &[rows, columns]) => {
103            write!(
104                out,
105                "{}::{}{}x{}",
106                NAMESPACE,
107                scalar.to_msl_name(),
108                common::vector_size_str(columns),
109                common::vector_size_str(rows)
110            )
111        }
112        (_, _) => Ok(()), // not meaningful
113    }
114}
115
116const fn scalar_is_int(scalar: crate::Scalar) -> bool {
117    use crate::ScalarKind::*;
118    match scalar.kind {
119        Sint | Uint | AbstractInt | Bool => true,
120        Float | AbstractFloat => false,
121    }
122}
123
124/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
125const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
126
127/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
128const REINTERPRET_PREFIX: &str = "reinterpreted_";
129
130/// Wrapper for identifier names for clamped level-of-detail values
131///
132/// Values of this type implement [`core::fmt::Display`], formatting as
133/// the name of the variable used to hold the cached clamped
134/// level-of-detail value for an `ImageLoad` expression.
135struct ClampedLod(Handle<crate::Expression>);
136
137impl Display for ClampedLod {
138    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
139        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
140    }
141}
142
143/// Wrapper for generating `struct _mslBufferSizes` member names for
144/// runtime-sized array lengths.
145///
146/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
147/// as an argument to the entry point. This argument's type in the MSL is
148/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
149/// each global variable containing a runtime-sized array.
150///
151/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
152/// runtime-sized array, then the value `ArraySize(global)` implements
153/// [`core::fmt::Display`], formatting as the name of the struct member carrying
154/// the number of elements in that runtime-sized array.
155///
156/// [`GlobalVariable`]: crate::GlobalVariable
157struct ArraySizeMember(Handle<crate::GlobalVariable>);
158
159impl Display for ArraySizeMember {
160    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
161        self.0.write_prefixed(f, "size")
162    }
163}
164
165/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
166///
167/// Implements [`core::fmt::Display`], formatting as a name derived from
168/// `target_type` and the variable name of `orig`.
169#[derive(Clone, Copy)]
170struct Reinterpreted<'a> {
171    target_type: &'a str,
172    orig: Handle<crate::Expression>,
173}
174
175impl<'a> Reinterpreted<'a> {
176    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
177        Self { target_type, orig }
178    }
179}
180
181impl Display for Reinterpreted<'_> {
182    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
183        f.write_str(REINTERPRET_PREFIX)?;
184        f.write_str(self.target_type)?;
185        self.orig.write_prefixed(f, "_e")
186    }
187}
188
189pub(super) struct TypeContext<'a> {
190    pub handle: Handle<crate::Type>,
191    pub gctx: proc::GlobalCtx<'a>,
192    pub names: &'a FastHashMap<NameKey, String>,
193    pub access: crate::StorageAccess,
194    pub first_time: bool,
195}
196
197impl TypeContext<'_> {
198    fn scalar(&self) -> Option<crate::Scalar> {
199        let ty = &self.gctx.types[self.handle];
200        ty.inner.scalar()
201    }
202
203    fn vector_size(&self) -> Option<crate::VectorSize> {
204        let ty = &self.gctx.types[self.handle];
205        match ty.inner {
206            crate::TypeInner::Vector { size, .. } => Some(size),
207            _ => None,
208        }
209    }
210
211    fn unwrap_array(self) -> Self {
212        match self.gctx.types[self.handle].inner {
213            crate::TypeInner::Array { base, .. } => Self {
214                handle: base,
215                ..self
216            },
217            _ => self,
218        }
219    }
220}
221
222impl Display for TypeContext<'_> {
223    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
224        let ty = &self.gctx.types[self.handle];
225        if ty.needs_alias() && !self.first_time {
226            let name = &self.names[&NameKey::Type(self.handle)];
227            return write!(out, "{name}");
228        }
229
230        match ty.inner {
231            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
232            crate::TypeInner::Atomic(scalar) => {
233                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
234            }
235            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
236            crate::TypeInner::Matrix {
237                columns,
238                rows,
239                scalar,
240            } => put_numeric_type(out, scalar, &[rows, columns]),
241            // Requires Metal-2.3
242            crate::TypeInner::CooperativeMatrix {
243                columns,
244                rows,
245                scalar,
246                role: _,
247            } => {
248                write!(
249                    out,
250                    "{NAMESPACE}::simdgroup_{}{}x{}",
251                    scalar.to_msl_name(),
252                    columns as u32,
253                    rows as u32,
254                )
255            }
256            crate::TypeInner::Pointer { base, space } => {
257                let sub = Self {
258                    handle: base,
259                    first_time: false,
260                    ..*self
261                };
262                let space_name = match space.to_msl_name() {
263                    Some(name) => name,
264                    None => return Ok(()),
265                };
266                write!(out, "{space_name} {sub}&")
267            }
268            crate::TypeInner::ValuePointer {
269                size,
270                scalar,
271                space,
272            } => {
273                match space.to_msl_name() {
274                    Some(name) => write!(out, "{name} ")?,
275                    None => return Ok(()),
276                };
277                match size {
278                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
279                    None => put_numeric_type(out, scalar, &[])?,
280                };
281
282                write!(out, "&")
283            }
284            crate::TypeInner::Array { base, .. } => {
285                let sub = Self {
286                    handle: base,
287                    first_time: false,
288                    ..*self
289                };
290                // Array lengths go at the end of the type definition,
291                // so just print the element type here.
292                write!(out, "{sub}")
293            }
294            crate::TypeInner::Struct { .. } => unreachable!(),
295            crate::TypeInner::Image {
296                dim,
297                arrayed,
298                class,
299            } => {
300                let dim_str = match dim {
301                    crate::ImageDimension::D1 => "1d",
302                    crate::ImageDimension::D2 => "2d",
303                    crate::ImageDimension::D3 => "3d",
304                    crate::ImageDimension::Cube => "cube",
305                };
306                let (texture_str, msaa_str, scalar, access) = match class {
307                    crate::ImageClass::Sampled { kind, multi } => {
308                        let (msaa_str, access) = if multi {
309                            ("_ms", "read")
310                        } else {
311                            ("", "sample")
312                        };
313                        let scalar = crate::Scalar { kind, width: 4 };
314                        ("texture", msaa_str, scalar, access)
315                    }
316                    crate::ImageClass::Depth { multi } => {
317                        let (msaa_str, access) = if multi {
318                            ("_ms", "read")
319                        } else {
320                            ("", "sample")
321                        };
322                        let scalar = crate::Scalar {
323                            kind: crate::ScalarKind::Float,
324                            width: 4,
325                        };
326                        ("depth", msaa_str, scalar, access)
327                    }
328                    crate::ImageClass::Storage { format, .. } => {
329                        let access = if self
330                            .access
331                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
332                        {
333                            "read_write"
334                        } else if self.access.contains(crate::StorageAccess::STORE) {
335                            "write"
336                        } else if self.access.contains(crate::StorageAccess::LOAD) {
337                            "read"
338                        } else {
339                            log::warn!(
340                                "Storage access for {:?} (name '{}'): {:?}",
341                                self.handle,
342                                ty.name.as_deref().unwrap_or_default(),
343                                self.access
344                            );
345                            unreachable!("module is not valid");
346                        };
347                        ("texture", "", format.into(), access)
348                    }
349                    crate::ImageClass::External => {
350                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
351                    }
352                };
353                let base_name = scalar.to_msl_name();
354                let array_str = if arrayed { "_array" } else { "" };
355                write!(
356                    out,
357                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
358                )
359            }
360            crate::TypeInner::Sampler { comparison: _ } => {
361                write!(out, "{NAMESPACE}::sampler")
362            }
363            crate::TypeInner::AccelerationStructure { vertex_return } => {
364                if vertex_return {
365                    unimplemented!("metal does not support vertex ray hit return")
366                }
367                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
368            }
369            crate::TypeInner::RayQuery { vertex_return } => {
370                if vertex_return {
371                    unimplemented!("metal does not support vertex ray hit return")
372                }
373                write!(out, "{}", super::ray::metal_intersector_ty())
374            }
375            crate::TypeInner::BindingArray { base, .. } => {
376                let base_tyname = Self {
377                    handle: base,
378                    first_time: false,
379                    ..*self
380                };
381
382                write!(
383                    out,
384                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
385                )
386            }
387        }
388    }
389}
390
391pub(super) struct TypedGlobalVariable<'a> {
392    pub module: &'a crate::Module,
393    pub names: &'a FastHashMap<NameKey, String>,
394    pub handle: Handle<crate::GlobalVariable>,
395    pub usage: valid::GlobalUse,
396    pub reference: bool,
397}
398
399struct TypedGlobalVariableParts {
400    ty_name: String,
401    var_name: String,
402}
403
404impl TypedGlobalVariable<'_> {
405    fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
406        let var = &self.module.global_variables[self.handle];
407        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
408
409        let storage_access = match var.space {
410            crate::AddressSpace::Storage { access } => access,
411            _ => match self.module.types[var.ty].inner {
412                crate::TypeInner::Image {
413                    class: crate::ImageClass::Storage { access, .. },
414                    ..
415                } => access,
416                crate::TypeInner::BindingArray { base, .. } => {
417                    match self.module.types[base].inner {
418                        crate::TypeInner::Image {
419                            class: crate::ImageClass::Storage { access, .. },
420                            ..
421                        } => access,
422                        _ => crate::StorageAccess::default(),
423                    }
424                }
425                _ => crate::StorageAccess::default(),
426            },
427        };
428        let ty_name = TypeContext {
429            handle: var.ty,
430            gctx: self.module.to_ctx(),
431            names: self.names,
432            access: storage_access,
433            first_time: false,
434        };
435
436        let access = if var.space.needs_access_qualifier()
437            && !self.usage.intersects(valid::GlobalUse::WRITE)
438        {
439            "const"
440        } else {
441            ""
442        };
443        let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
444            (Some(space), crate::AddressSpace::WorkGroup) => {
445                ("", space, access, if self.reference { "&" } else { "" })
446            }
447            (Some(space), _) if self.reference => {
448                let coherent = if var
449                    .memory_decorations
450                    .contains(crate::MemoryDecorations::COHERENT)
451                {
452                    "coherent "
453                } else {
454                    ""
455                };
456                (coherent, space, access, "&")
457            }
458            _ => ("", "", "", ""),
459        };
460
461        let ty = format!(
462            "{coherent}{space}{}{ty_name}{}{access}{reference}",
463            if space.is_empty() { "" } else { " " },
464            if access.is_empty() { "" } else { " " },
465        );
466
467        Ok(TypedGlobalVariableParts {
468            ty_name: ty,
469            var_name: name.clone(),
470        })
471    }
472    pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
473        let parts = self.to_parts()?;
474
475        Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
476    }
477}
478
/// Identifies a wrapper function by the operation it wraps and the types
/// it is specialized for.
///
/// Values are hashable and comparable for equality; in this file they are the
/// element type of the `wrapped_functions` set on `Writer`.
/// NOTE(review): presumably the set records which wrappers have already been
/// written — the code that does so is outside this chunk; confirm.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// A unary operator, keyed by operator and operand type
    /// (optional vector size plus scalar).
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A binary operator, keyed by operator and both operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A math function, keyed by function and argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A cast, keyed by source scalar, optional vector size and
    /// destination scalar.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// An image load, keyed by image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// An image sample, keyed by image class and the `clamp_to_edge` flag.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// An image size query, keyed by image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// A cooperative-matrix load, keyed by address-space name, matrix
    /// dimensions and scalar.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// A cooperative-matrix multiply-add, keyed by address-space name,
    /// the three sizes involved and the scalar.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// A ray-query intersection getter, keyed by the `committed` flag.
    RayQueryGetIntersection {
        committed: bool,
    },
}
526
/// State of the MSL writer; `W` is the type of the output sink `out`.
///
/// NOTE(review): the methods that use these fields are outside this chunk;
/// field descriptions without an original comment are inferred from the field
/// types and should be confirmed.
pub struct Writer<W> {
    /// Output sink.
    pub(super) out: W,
    /// Table mapping name keys to the identifiers used for them.
    pub(super) names: FastHashMap<NameKey, String>,
    pub(super) named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    pub(super) need_bake_expressions: back::NeedBakeExpressions,
    pub(super) namer: proc::Namer,
    /// Set of wrapper-function keys; see [`WrappedFunction`].
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    // The two pointer sets below exist only in test builds.
    #[cfg(test)]
    pub(super) put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    pub(super) put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    pub(super) struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    pub(super) needs_object_memory_barriers: bool,
}
544
545impl crate::Scalar {
546    pub(super) fn to_msl_name(self) -> &'static str {
547        use crate::ScalarKind as Sk;
548        match self {
549            Self {
550                kind: Sk::Float,
551                width: 4,
552            } => "float",
553            Self {
554                kind: Sk::Float,
555                width: 2,
556            } => "half",
557            Self {
558                kind: Sk::Sint,
559                width: 2,
560            } => "short",
561            Self {
562                kind: Sk::Uint,
563                width: 2,
564            } => "ushort",
565            Self {
566                kind: Sk::Sint,
567                width: 4,
568            } => "int",
569            Self {
570                kind: Sk::Uint,
571                width: 4,
572            } => "uint",
573            Self {
574                kind: Sk::Sint,
575                width: 8,
576            } => "long",
577            Self {
578                kind: Sk::Uint,
579                width: 8,
580            } => "ulong",
581            Self {
582                kind: Sk::Bool,
583                width: _,
584            } => "bool",
585            Self {
586                kind: Sk::AbstractInt | Sk::AbstractFloat,
587                width: _,
588            } => unreachable!("Found Abstract scalar kind"),
589            _ => unreachable!("Unsupported scalar kind: {:?}", self),
590        }
591    }
592}
593
/// Return `","` when a separator is required and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
601
602fn should_pack_struct_member(
603    members: &[crate::StructMember],
604    span: u32,
605    index: usize,
606    module: &crate::Module,
607) -> Option<crate::Scalar> {
608    let member = &members[index];
609
610    let ty_inner = &module.types[member.ty].inner;
611    let last_offset = member.offset + ty_inner.size(module.to_ctx());
612    let next_offset = match members.get(index + 1) {
613        Some(next) => next.offset,
614        None => span,
615    };
616    let is_tight = next_offset == last_offset;
617
618    match *ty_inner {
619        crate::TypeInner::Vector {
620            size: crate::VectorSize::Tri,
621            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
622        } if is_tight => Some(scalar),
623        _ => None,
624    }
625}
626
627fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
628    match arena[ty].inner {
629        crate::TypeInner::Struct { ref members, .. } => {
630            if let Some(member) = members.last() {
631                if let crate::TypeInner::Array {
632                    size: crate::ArraySize::Dynamic,
633                    ..
634                } = arena[member.ty].inner
635                {
636                    return true;
637                }
638            }
639            false
640        }
641        crate::TypeInner::Array {
642            size: crate::ArraySize::Dynamic,
643            ..
644        } => true,
645        _ => false,
646    }
647}
648
649impl crate::AddressSpace {
650    /// Returns true if global variables in this address space are
651    /// passed in function arguments. These arguments need to be
652    /// passed through any functions called from the entry point.
653    const fn needs_pass_through(&self) -> bool {
654        match *self {
655            Self::Uniform
656            | Self::Storage { .. }
657            | Self::Private
658            | Self::WorkGroup
659            | Self::Immediate
660            | Self::Handle
661            | Self::TaskPayload => true,
662            Self::Function => false,
663            Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
664        }
665    }
666
667    /// Returns true if the address space may need a "const" qualifier.
668    const fn needs_access_qualifier(&self) -> bool {
669        match *self {
670            //Note: we are ignoring the storage access here, and instead
671            // rely on the actual use of a global by functions. This means we
672            // may end up with "const" even if the binding is read-write,
673            // and that should be OK.
674            Self::Storage { .. } => true,
675            Self::TaskPayload => true,
676            Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
677            // These should always be read-write.
678            Self::Private | Self::WorkGroup => false,
679            // These translate to `constant` address space, no need for qualifiers.
680            Self::Uniform | Self::Immediate => false,
681            // Not applicable.
682            Self::Handle | Self::Function => false,
683        }
684    }
685
686    const fn to_msl_name(self) -> Option<&'static str> {
687        match self {
688            Self::Handle => None,
689            Self::Uniform | Self::Immediate => Some("constant"),
690            Self::Storage { .. } => Some("device"),
691            // note for `RayPayload`, this probably needs to be emulated as a
692            // private variable, as metal has essentially an inout input
693            // for where it is passed.
694            Self::Private | Self::Function | Self::RayPayload => Some("thread"),
695            Self::WorkGroup => Some("threadgroup"),
696            Self::TaskPayload => Some("object_data"),
697            Self::IncomingRayPayload => Some("ray_data"),
698        }
699    }
700}
701
702impl crate::Type {
703    // Returns `true` if we need to emit an alias for this type.
704    const fn needs_alias(&self) -> bool {
705        use crate::TypeInner as Ti;
706
707        match self.inner {
708            // value types are concise enough, we only alias them if they are named
709            Ti::Scalar(_)
710            | Ti::Vector { .. }
711            | Ti::Matrix { .. }
712            | Ti::CooperativeMatrix { .. }
713            | Ti::Atomic(_)
714            | Ti::Pointer { .. }
715            | Ti::ValuePointer { .. } => self.name.is_some(),
716            // composite types are better to be aliased, regardless of the name
717            Ti::Struct { .. } | Ti::Array { .. } => true,
718            // handle types may be different, depending on the global var access, so we always inline them
719            Ti::Image { .. }
720            | Ti::Sampler { .. }
721            | Ti::AccelerationStructure { .. }
722            | Ti::RayQuery { .. }
723            | Ti::BindingArray { .. } => false,
724        }
725    }
726}
727
/// Where a function comes from: a regular function, identified by its
/// handle, or an entry point, identified by its index.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A function identified by its handle.
    Handle(Handle<crate::Function>),
    /// An entry point identified by its index.
    EntryPoint(proc::EntryPointIndex),
}
733
734trait NameKeyExt {
735    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
736        match origin {
737            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
738            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
739        }
740    }
741
742    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
743    /// policy when it needs to produce a pointer-typed result for an OOB access. These
744    /// are unique per accessed type, so the second argument is a type handle. See docs
745    /// for [`crate::back::msl`].
746    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
747        match origin {
748            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
749            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
750        }
751    }
752}
753
754impl NameKeyExt for NameKey {}
755
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use the given expression directly as the level of detail.
    Direct(Handle<crate::Expression>),
    /// Use the saved, clamped level of detail.
    /// NOTE(review): presumably the handle is that of the `ImageLoad`
    /// expression the temporary is named after — confirm at the use sites.
    Restricted(Handle<crate::Expression>),
}
770
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Expression giving the texel coordinate.
    coordinate: Handle<crate::Expression>,
    /// Expression giving the array index, if one is supplied.
    array_index: Option<Handle<crate::Expression>>,
    /// Expression giving the sample index, if one is supplied.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if one is supplied.
    level: Option<LevelOfDetail>,
}
786
/// Per-function context bundling the function, its validation info, the
/// enclosing module and the backend options.
pub(super) struct ExpressionContext<'a> {
    /// The function being processed.
    pub(super) function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; indexed by expression handle to obtain
    /// expression types.
    pub(super) info: &'a valid::FunctionInfo,
    /// The module containing `function`.
    pub(super) module: &'a crate::Module,
    /// Validation info for the whole module.
    pub(super) mod_info: &'a valid::ModuleInfo,
    pub(super) pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the targeted MSL version as (major, minor) — confirm.
    pub(super) lang_version: (u8, u8),
    /// Bounds-check policies, consulted by `choose_bounds_check_policy`.
    pub(super) policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    pub(super) guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    pub(super) force_loop_bounding: bool,
}
804
805impl<'a> ExpressionContext<'a> {
806    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
807        self.info[handle].ty.inner_with(&self.module.types)
808    }
809
810    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
811    ///
812    /// Only mipmapped images need to specify a level of detail. Since 1D
813    /// textures cannot have mipmaps, MSL requires that the level argument to
814    /// texture1d queries and accesses must be a constexpr 0. It's easiest
815    /// just to omit the level entirely for 1D textures.
816    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
817        let image_ty = self.resolve_type(image);
818        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
819            class.is_mipmapped() && dim != crate::ImageDimension::D1
820        } else {
821            false
822        }
823    }
824
825    fn choose_bounds_check_policy(
826        &self,
827        pointer: Handle<crate::Expression>,
828    ) -> index::BoundsCheckPolicy {
829        self.policies
830            .choose_policy(pointer, &self.module.types, self.info)
831    }
832
833    /// See docs for [`proc::index::access_needs_check`].
834    fn access_needs_check(
835        &self,
836        base: Handle<crate::Expression>,
837        index: index::GuardedIndex,
838    ) -> Option<index::IndexableLength> {
839        index::access_needs_check(
840            base,
841            index,
842            self.module,
843            &self.function.expressions,
844            self.info,
845        )
846    }
847
848    /// See docs for [`proc::index::bounds_check_iter`].
849    fn bounds_check_iter(
850        &self,
851        chain: Handle<crate::Expression>,
852    ) -> impl Iterator<Item = BoundsCheck> + '_ {
853        index::bounds_check_iter(chain, self.module, self.function, self.info)
854    }
855
856    /// See docs for [`proc::index::oob_local_types`].
857    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
858        index::oob_local_types(self.module, self.function, self.info, self.policies)
859    }
860
861    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
862        match self.function.expressions[expr_handle] {
863            crate::Expression::AccessIndex { base, index } => {
864                let ty = match *self.resolve_type(base) {
865                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
866                    ref ty => ty,
867                };
868                match *ty {
869                    crate::TypeInner::Struct {
870                        ref members, span, ..
871                    } => should_pack_struct_member(members, span, index as usize, self.module),
872                    _ => None,
873                }
874            }
875            _ => None,
876        }
877    }
878}
879
/// State needed while writing statements: the expression context plus
/// statement-level extras.
pub(super) struct StatementContext<'a> {
    /// Context used for every expression written as part of a statement.
    pub(super) expression: ExpressionContext<'a>,
    /// Optional name of a result struct.
    // NOTE(review): presumably the entry point's return struct name — confirm
    // against the statement writer.
    pub(super) result_struct: Option<&'a str>,
}
884
885impl<W: Write> Writer<W> {
886    /// Creates a new `Writer` instance.
887    pub fn new(out: W) -> Self {
888        Writer {
889            out,
890            names: FastHashMap::default(),
891            named_expressions: Default::default(),
892            need_bake_expressions: Default::default(),
893            namer: proc::Namer::default(),
894            wrapped_functions: FastHashSet::default(),
895            #[cfg(test)]
896            put_expression_stack_pointers: Default::default(),
897            #[cfg(test)]
898            put_block_stack_pointers: Default::default(),
899            struct_member_pads: FastHashSet::default(),
900            needs_object_memory_barriers: false,
901        }
902    }
903
904    /// Finishes writing and returns the output.
905    // See https://github.com/rust-lang/rust-clippy/issues/4979.
906    pub fn finish(self) -> W {
907        self.out
908    }
909
910    /// Generates statements to be inserted immediately before and at the very
911    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
912    /// The 0th item of the returned tuple should be inserted immediately prior
913    /// to the loop and the 1st item should be inserted at the very start of
914    /// the loop body.
915    ///
916    /// # What is this trying to solve?
917    ///
918    /// In Metal Shading Language, an infinite loop has undefined behavior.
919    /// (This rule is inherited from C++14.) This means that, if the MSL
920    /// compiler determines that a given loop will never exit, it may assume
921    /// that it is never reached. It may thus assume that any conditions
922    /// sufficient to cause the loop to be reached must be false. Like many
923    /// optimizing compilers, MSL uses this kind of analysis to establish limits
924    /// on the range of values variables involved in those conditions might
925    /// hold.
926    ///
927    /// For example, suppose the MSL compiler sees the code:
928    ///
929    /// ```ignore
930    /// if (i >= 10) {
931    ///     while (true) { }
932    /// }
933    /// ```
934    ///
935    /// It will recognize that the `while` loop will never terminate, conclude
936    /// that it must be unreachable, and thus infer that, if this code is
937    /// reached, then `i < 10` at that point.
938    ///
939    /// Now suppose that, at some point where `i` has the same value as above,
940    /// the compiler sees the code:
941    ///
942    /// ```ignore
943    /// if (i < 10) {
944    ///     a[i] = 1;
945    /// }
946    /// ```
947    ///
948    /// Because the compiler is confident that `i < 10`, it will make the
949    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
950    ///
951    /// ```ignore
952    /// a[i] = 1;
953    /// ```
954    ///
955    /// If that `if` condition was injected by Naga to implement a bounds check,
956    /// the MSL compiler's optimizations could allow out-of-bounds array
957    /// accesses to occur.
958    ///
959    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
960    /// that a loop is infinite, so an attacker could craft a Naga module
961    /// containing an infinite loop protected by conditions that cause the Metal
962    /// compiler to remove bounds checks that Naga injected elsewhere in the
963    /// function.
964    ///
965    /// This rewrite could occur even if the conditional assignment appears
966    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
967    /// reached. This would allow the attacker to save the results of
968    /// unauthorized reads somewhere accessible before entering the infinite
969    /// loop. But even worse, the MSL compiler has been observed to simply
970    /// delete the infinite loop entirely, so that even code dominated by the
971    /// loop becomes reachable. This would make the attack even more flexible,
972    /// since shaders that would appear to never terminate would actually exit
973    /// nicely, after having stolen data from elsewhere in the GPU address
974    /// space.
975    ///
976    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
977    /// generates is infinite. One approach would be to add inline assembly to
978    /// each loop that is annotated as potentially branching out of the loop,
979    /// but which in fact generates no instructions. Unfortunately, inline
980    /// assembly is not handled correctly by some Metal device drivers.
981    ///
982    /// A previously used approach was to add the following code to the bottom
983    /// of every loop:
984    ///
985    /// ```ignore
986    /// if (volatile bool unpredictable = false; unpredictable)
987    ///     break;
988    /// ```
989    ///
990    /// Although the `if` condition will always be false in any real execution,
991    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
992    /// it must assume that the `break` might be reached, and hence that the
993    /// loop is not unbounded. This prevents the range analysis impact described
994    /// above. Unfortunately this prevented the compiler from making important,
995    /// and safe, optimizations such as loop unrolling and was observed to
996    /// significantly hurt performance.
997    ///
    /// Our current approach declares a counter before every loop and
    /// decrements it every iteration, breaking after 2^64 iterations:
    ///
    /// ```ignore
    /// uint2 loop_bound = uint2(4294967295u);
    /// while (true) {
    ///   if (metal::all(loop_bound == uint2(0u))) { break; }
    ///   loop_bound -= uint2(loop_bound.y == 0u, 1u);
    /// }
    /// ```
1008    ///
1009    /// This convinces the compiler that the loop is finite and therefore may
1010    /// execute, whilst at the same time allowing optimizations such as loop
1011    /// unrolling. Furthermore the 64-bit counter is large enough it seems
1012    /// implausible that it would affect the execution of any shader.
1013    ///
1014    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
1015    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
1016    fn gen_force_bounded_loop_statements(
1017        &mut self,
1018        level: back::Level,
1019        context: &StatementContext,
1020    ) -> Option<(String, String)> {
1021        if !context.expression.force_loop_bounding {
1022            return None;
1023        }
1024
1025        let loop_bound_name = self.namer.call("loop_bound");
1026        // Count down from u32::MAX rather than up from 0 to avoid hang on
1027        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
1028        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1029        let level = level.next();
1030        let break_and_inc = format!(
1031            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1032{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1033        );
1034
1035        Some((decl, break_and_inc))
1036    }
1037
1038    fn put_call_parameters(
1039        &mut self,
1040        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1041        context: &ExpressionContext,
1042    ) -> BackendResult {
1043        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1044            writer.put_expression(expr, context, true)
1045        })
1046    }
1047
1048    fn put_call_parameters_impl<C, E>(
1049        &mut self,
1050        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1051        ctx: &C,
1052        put_expression: E,
1053    ) -> BackendResult
1054    where
1055        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1056    {
1057        write!(self.out, "(")?;
1058        for (i, handle) in parameters.enumerate() {
1059            if i != 0 {
1060                write!(self.out, ", ")?;
1061            }
1062            put_expression(self, ctx, handle)?;
1063        }
1064        write!(self.out, ")")?;
1065        Ok(())
1066    }
1067
1068    /// Writes the local variables of the given function, as well as any extra
1069    /// out-of-bounds locals that are needed.
1070    ///
1071    /// The names of the OOB locals are also added to `self.names` at the same
1072    /// time.
1073    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1074        let oob_local_types = context.oob_local_types();
1075        for &ty in oob_local_types.iter() {
1076            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1077            self.names.insert(name_key, self.namer.call("oob"));
1078        }
1079
1080        for (name_key, ty, init) in context
1081            .function
1082            .local_variables
1083            .iter()
1084            .map(|(local_handle, local)| {
1085                let name_key = NameKey::local(context.origin, local_handle);
1086                (name_key, local.ty, local.init)
1087            })
1088            .chain(oob_local_types.iter().map(|&ty| {
1089                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1090                (name_key, ty, None)
1091            }))
1092        {
1093            let ty_name = TypeContext {
1094                handle: ty,
1095                gctx: context.module.to_ctx(),
1096                names: &self.names,
1097                access: crate::StorageAccess::empty(),
1098                first_time: false,
1099            };
1100            write!(
1101                self.out,
1102                "{}{} {}",
1103                back::INDENT,
1104                ty_name,
1105                self.names[&name_key]
1106            )?;
1107            match init {
1108                Some(value) => {
1109                    write!(self.out, " = ")?;
1110                    self.put_expression(value, context, true)?;
1111                }
1112                None => {
1113                    write!(self.out, " = {{}}")?;
1114                }
1115            };
1116            writeln!(self.out, ";")?;
1117        }
1118        Ok(())
1119    }
1120
1121    fn put_level_of_detail(
1122        &mut self,
1123        level: LevelOfDetail,
1124        context: &ExpressionContext,
1125    ) -> BackendResult {
1126        match level {
1127            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1128            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1129        }
1130        Ok(())
1131    }
1132
1133    fn put_image_query(
1134        &mut self,
1135        image: Handle<crate::Expression>,
1136        query: &str,
1137        level: Option<LevelOfDetail>,
1138        context: &ExpressionContext,
1139    ) -> BackendResult {
1140        self.put_expression(image, context, false)?;
1141        write!(self.out, ".get_{query}(")?;
1142        if let Some(level) = level {
1143            self.put_level_of_detail(level, context)?;
1144        }
1145        write!(self.out, ")")?;
1146        Ok(())
1147    }
1148
1149    fn put_image_size_query(
1150        &mut self,
1151        image: Handle<crate::Expression>,
1152        level: Option<LevelOfDetail>,
1153        kind: crate::ScalarKind,
1154        context: &ExpressionContext,
1155    ) -> BackendResult {
1156        if let crate::TypeInner::Image {
1157            class: crate::ImageClass::External,
1158            ..
1159        } = *context.resolve_type(image)
1160        {
1161            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1162            self.put_expression(image, context, true)?;
1163            write!(self.out, ")")?;
1164            return Ok(());
1165        }
1166
1167        //Note: MSL only has separate width/height/depth queries,
1168        // so compose the result of them.
1169        let dim = match *context.resolve_type(image) {
1170            crate::TypeInner::Image { dim, .. } => dim,
1171            ref other => unreachable!("Unexpected type {:?}", other),
1172        };
1173        let scalar = crate::Scalar { kind, width: 4 };
1174        let coordinate_type = scalar.to_msl_name();
1175        match dim {
1176            crate::ImageDimension::D1 => {
1177                // Since 1D textures never have mipmaps, MSL requires that the
1178                // `level` argument be a constexpr 0. It's simplest for us just
1179                // to pass `None` and omit the level entirely.
1180                if kind == crate::ScalarKind::Uint {
1181                    // No need to construct a vector. No cast needed.
1182                    self.put_image_query(image, "width", None, context)?;
1183                } else {
1184                    // There's no definition for `int` in the `metal` namespace.
1185                    write!(self.out, "int(")?;
1186                    self.put_image_query(image, "width", None, context)?;
1187                    write!(self.out, ")")?;
1188                }
1189            }
1190            crate::ImageDimension::D2 => {
1191                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1192                self.put_image_query(image, "width", level, context)?;
1193                write!(self.out, ", ")?;
1194                self.put_image_query(image, "height", level, context)?;
1195                write!(self.out, ")")?;
1196            }
1197            crate::ImageDimension::D3 => {
1198                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1199                self.put_image_query(image, "width", level, context)?;
1200                write!(self.out, ", ")?;
1201                self.put_image_query(image, "height", level, context)?;
1202                write!(self.out, ", ")?;
1203                self.put_image_query(image, "depth", level, context)?;
1204                write!(self.out, ")")?;
1205            }
1206            crate::ImageDimension::Cube => {
1207                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1208                self.put_image_query(image, "width", level, context)?;
1209                write!(self.out, ")")?;
1210            }
1211        }
1212        Ok(())
1213    }
1214
1215    fn put_cast_to_uint_scalar_or_vector(
1216        &mut self,
1217        expr: Handle<crate::Expression>,
1218        context: &ExpressionContext,
1219    ) -> BackendResult {
1220        // coordinates in IR are int, but Metal expects uint
1221        match *context.resolve_type(expr) {
1222            crate::TypeInner::Scalar(_) => {
1223                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1224            }
1225            crate::TypeInner::Vector { size, .. } => {
1226                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1227            }
1228            _ => {
1229                return Err(Error::GenericValidation(
1230                    "Invalid type for image coordinate".into(),
1231                ))
1232            }
1233        };
1234
1235        write!(self.out, "(")?;
1236        self.put_expression(expr, context, true)?;
1237        write!(self.out, ")")?;
1238        Ok(())
1239    }
1240
1241    fn put_image_sample_level(
1242        &mut self,
1243        image: Handle<crate::Expression>,
1244        level: crate::SampleLevel,
1245        context: &ExpressionContext,
1246    ) -> BackendResult {
1247        let has_levels = context.image_needs_lod(image);
1248        match level {
1249            crate::SampleLevel::Auto => {}
1250            crate::SampleLevel::Zero => {
1251                //TODO: do we support Zero on `Sampled` image classes?
1252            }
1253            _ if !has_levels => {
1254                log::warn!("1D image can't be sampled with level {level:?}");
1255            }
1256            crate::SampleLevel::Exact(h) => {
1257                write!(self.out, ", {NAMESPACE}::level(")?;
1258                self.put_expression(h, context, true)?;
1259                write!(self.out, ")")?;
1260            }
1261            crate::SampleLevel::Bias(h) => {
1262                write!(self.out, ", {NAMESPACE}::bias(")?;
1263                self.put_expression(h, context, true)?;
1264                write!(self.out, ")")?;
1265            }
1266            crate::SampleLevel::Gradient { x, y } => {
1267                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1268                self.put_expression(x, context, true)?;
1269                write!(self.out, ", ")?;
1270                self.put_expression(y, context, true)?;
1271                write!(self.out, ")")?;
1272            }
1273        }
1274        Ok(())
1275    }
1276
1277    fn put_image_coordinate_limits(
1278        &mut self,
1279        image: Handle<crate::Expression>,
1280        level: Option<LevelOfDetail>,
1281        context: &ExpressionContext,
1282    ) -> BackendResult {
1283        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1284        write!(self.out, " - 1")?;
1285        Ok(())
1286    }
1287
1288    /// General function for writing restricted image indexes.
1289    ///
1290    /// This is used to produce restricted mip levels, array indices, and sample
1291    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1292    /// [`Restrict`] bounds check policy.
1293    ///
1294    /// This function writes an expression of the form:
1295    ///
1296    /// ```ignore
1297    ///
1298    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1299    ///
1300    /// ```
1301    ///
1302    /// [`ImageLoad`]: crate::Expression::ImageLoad
1303    /// [`ImageStore`]: crate::Statement::ImageStore
1304    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1305    fn put_restricted_scalar_image_index(
1306        &mut self,
1307        image: Handle<crate::Expression>,
1308        index: Handle<crate::Expression>,
1309        limit_method: &str,
1310        context: &ExpressionContext,
1311    ) -> BackendResult {
1312        write!(self.out, "{NAMESPACE}::min(uint(")?;
1313        self.put_expression(index, context, true)?;
1314        write!(self.out, "), ")?;
1315        self.put_expression(image, context, false)?;
1316        write!(self.out, ".{limit_method}() - 1)")?;
1317        Ok(())
1318    }
1319
1320    fn put_restricted_texel_address(
1321        &mut self,
1322        image: Handle<crate::Expression>,
1323        address: &TexelAddress,
1324        context: &ExpressionContext,
1325    ) -> BackendResult {
1326        // Write the coordinate.
1327        write!(self.out, "{NAMESPACE}::min(")?;
1328        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1329        write!(self.out, ", ")?;
1330        self.put_image_coordinate_limits(image, address.level, context)?;
1331        write!(self.out, ")")?;
1332
1333        // Write the array index, if present.
1334        if let Some(array_index) = address.array_index {
1335            write!(self.out, ", ")?;
1336            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1337        }
1338
1339        // Write the sample index, if present.
1340        if let Some(sample) = address.sample {
1341            write!(self.out, ", ")?;
1342            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1343        }
1344
1345        // The level of detail should be clamped and cached by
1346        // `put_cache_restricted_level`, so we don't need to clamp it here.
1347        if let Some(level) = address.level {
1348            write!(self.out, ", ")?;
1349            self.put_level_of_detail(level, context)?;
1350        }
1351
1352        Ok(())
1353    }
1354
1355    /// Write an expression that is true if the given image access is in bounds.
1356    fn put_image_access_bounds_check(
1357        &mut self,
1358        image: Handle<crate::Expression>,
1359        address: &TexelAddress,
1360        context: &ExpressionContext,
1361    ) -> BackendResult {
1362        let mut conjunction = "";
1363
1364        // First, check the level of detail. Only if that is in bounds can we
1365        // use it to find the appropriate bounds for the coordinates.
1366        let level = if let Some(level) = address.level {
1367            write!(self.out, "uint(")?;
1368            self.put_level_of_detail(level, context)?;
1369            write!(self.out, ") < ")?;
1370            self.put_expression(image, context, true)?;
1371            write!(self.out, ".get_num_mip_levels()")?;
1372            conjunction = " && ";
1373            Some(level)
1374        } else {
1375            None
1376        };
1377
1378        // Check sample index, if present.
1379        if let Some(sample) = address.sample {
1380            write!(self.out, "uint(")?;
1381            self.put_expression(sample, context, true)?;
1382            write!(self.out, ") < ")?;
1383            self.put_expression(image, context, true)?;
1384            write!(self.out, ".get_num_samples()")?;
1385            conjunction = " && ";
1386        }
1387
1388        // Check array index, if present.
1389        if let Some(array_index) = address.array_index {
1390            write!(self.out, "{conjunction}uint(")?;
1391            self.put_expression(array_index, context, true)?;
1392            write!(self.out, ") < ")?;
1393            self.put_expression(image, context, true)?;
1394            write!(self.out, ".get_array_size()")?;
1395            conjunction = " && ";
1396        }
1397
1398        // Finally, check if the coordinates are within bounds.
1399        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1400            crate::TypeInner::Vector { .. } => true,
1401            _ => false,
1402        };
1403        write!(self.out, "{conjunction}")?;
1404        if coord_is_vector {
1405            write!(self.out, "{NAMESPACE}::all(")?;
1406        }
1407        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1408        write!(self.out, " < ")?;
1409        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1410        if coord_is_vector {
1411            write!(self.out, ")")?;
1412        }
1413
1414        Ok(())
1415    }
1416
1417    fn put_image_load(
1418        &mut self,
1419        load: Handle<crate::Expression>,
1420        image: Handle<crate::Expression>,
1421        mut address: TexelAddress,
1422        context: &ExpressionContext,
1423    ) -> BackendResult {
1424        if let crate::TypeInner::Image {
1425            class: crate::ImageClass::External,
1426            ..
1427        } = *context.resolve_type(image)
1428        {
1429            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1430            self.put_expression(image, context, true)?;
1431            write!(self.out, ", ")?;
1432            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1433            write!(self.out, ")")?;
1434            return Ok(());
1435        }
1436
1437        match context.policies.image_load {
1438            proc::BoundsCheckPolicy::Restrict => {
1439                // Use the cached restricted level of detail, if any. Omit the
1440                // level altogether for 1D textures.
1441                if address.level.is_some() {
1442                    address.level = if context.image_needs_lod(image) {
1443                        Some(LevelOfDetail::Restricted(load))
1444                    } else {
1445                        None
1446                    }
1447                }
1448
1449                self.put_expression(image, context, false)?;
1450                write!(self.out, ".read(")?;
1451                self.put_restricted_texel_address(image, &address, context)?;
1452                write!(self.out, ")")?;
1453            }
1454            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1455                write!(self.out, "(")?;
1456                self.put_image_access_bounds_check(image, &address, context)?;
1457                write!(self.out, " ? ")?;
1458                self.put_unchecked_image_load(image, &address, context)?;
1459                write!(self.out, ": DefaultConstructible())")?;
1460            }
1461            proc::BoundsCheckPolicy::Unchecked => {
1462                self.put_unchecked_image_load(image, &address, context)?;
1463            }
1464        }
1465
1466        Ok(())
1467    }
1468
1469    fn put_unchecked_image_load(
1470        &mut self,
1471        image: Handle<crate::Expression>,
1472        address: &TexelAddress,
1473        context: &ExpressionContext,
1474    ) -> BackendResult {
1475        self.put_expression(image, context, false)?;
1476        write!(self.out, ".read(")?;
1477        // coordinates in IR are int, but Metal expects uint
1478        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1479        if let Some(expr) = address.array_index {
1480            write!(self.out, ", ")?;
1481            self.put_expression(expr, context, true)?;
1482        }
1483        if let Some(sample) = address.sample {
1484            write!(self.out, ", ")?;
1485            self.put_expression(sample, context, true)?;
1486        }
1487        if let Some(level) = address.level {
1488            if context.image_needs_lod(image) {
1489                write!(self.out, ", ")?;
1490                self.put_level_of_detail(level, context)?;
1491            }
1492        }
1493        write!(self.out, ")")?;
1494
1495        Ok(())
1496    }
1497
1498    fn put_image_atomic(
1499        &mut self,
1500        level: back::Level,
1501        image: Handle<crate::Expression>,
1502        address: &TexelAddress,
1503        fun: crate::AtomicFunction,
1504        value: Handle<crate::Expression>,
1505        context: &StatementContext,
1506    ) -> BackendResult {
1507        write!(self.out, "{level}")?;
1508        self.put_expression(image, &context.expression, false)?;
1509        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1510            fun.to_msl_64_bit()?
1511        } else {
1512            fun.to_msl()
1513        };
1514        write!(self.out, ".atomic_{op}(")?;
1515        // coordinates in IR are int, but Metal expects uint
1516        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1517        write!(self.out, ", ")?;
1518        self.put_expression(value, &context.expression, true)?;
1519        writeln!(self.out, ");")?;
1520
1521        // Workaround for Apple Metal TBDR driver bug: fragment shader atomic
1522        // texture writes randomly drop unless followed by a standard texture
1523        // write. Insert a dead-code write behind an unprovable condition so
1524        // the compiler emits proper memory safety barriers.
1525        // See: https://projects.blender.org/blender/blender/commit/aa95220576706122d79c91c7f5c522e6c7416425
1526        let value_ty = context.expression.resolve_type(value);
1527        let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1528            (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1529            (_, Some(8)) => "ulong4(0uL)",
1530            _ => "uint4(0u)",
1531        };
1532        let coord_ty = context.expression.resolve_type(address.coordinate);
1533        let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1534            ""
1535        } else {
1536            ".x"
1537        };
1538        write!(self.out, "{level}if (")?;
1539        self.put_expression(address.coordinate, &context.expression, true)?;
1540        write!(self.out, "{x} == -99999) {{ ")?;
1541        self.put_expression(image, &context.expression, false)?;
1542        write!(self.out, ".write({zero_value}, ")?;
1543        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1544        if let Some(array_index) = address.array_index {
1545            write!(self.out, ", ")?;
1546            self.put_expression(array_index, &context.expression, true)?;
1547        }
1548        writeln!(self.out, "); }}")?;
1549
1550        Ok(())
1551    }
1552
1553    fn put_image_store(
1554        &mut self,
1555        level: back::Level,
1556        image: Handle<crate::Expression>,
1557        address: &TexelAddress,
1558        value: Handle<crate::Expression>,
1559        context: &StatementContext,
1560    ) -> BackendResult {
1561        write!(self.out, "{level}")?;
1562        self.put_expression(image, &context.expression, false)?;
1563        write!(self.out, ".write(")?;
1564        self.put_expression(value, &context.expression, true)?;
1565        write!(self.out, ", ")?;
1566        // coordinates in IR are int, but Metal expects uint
1567        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1568        if let Some(expr) = address.array_index {
1569            write!(self.out, ", ")?;
1570            self.put_expression(expr, &context.expression, true)?;
1571        }
1572        writeln!(self.out, ");")?;
1573
1574        Ok(())
1575    }
1576
1577    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1578    ///
1579    /// The 'maximum valid index' is simply one less than the array's length.
1580    ///
1581    /// This emits an expression of the form `a / b`, so the caller must
1582    /// parenthesize its output if it will be applying operators of higher
1583    /// precedence.
1584    ///
1585    /// `handle` must be the handle of a global variable whose final member is a
1586    /// dynamically sized array.
1587    fn put_dynamic_array_max_index(
1588        &mut self,
1589        handle: Handle<crate::GlobalVariable>,
1590        context: &ExpressionContext,
1591    ) -> BackendResult {
1592        let global = &context.module.global_variables[handle];
1593        let (offset, array_ty) = match context.module.types[global.ty].inner {
1594            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1595                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1596                None => return Err(Error::GenericValidation("Struct has no members".into())),
1597            },
1598            crate::TypeInner::Array {
1599                size: crate::ArraySize::Dynamic,
1600                ..
1601            } => (0, global.ty),
1602            ref ty => {
1603                return Err(Error::GenericValidation(format!(
1604                    "Expected type with dynamic array, got {ty:?}"
1605                )))
1606            }
1607        };
1608
1609        let (size, stride) = match context.module.types[array_ty].inner {
1610            crate::TypeInner::Array { base, stride, .. } => (
1611                context.module.types[base]
1612                    .inner
1613                    .size(context.module.to_ctx()),
1614                stride,
1615            ),
1616            ref ty => {
1617                return Err(Error::GenericValidation(format!(
1618                    "Expected array type, got {ty:?}"
1619                )))
1620            }
1621        };
1622
1623        // When the stride length is larger than the size, the final element's stride of
1624        // bytes would have padding following the value. But the buffer size in
1625        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1626        // enough to hold the actual values' bytes.
1627        //
1628        // So subtract off the size to get a byte size that falls at the start or within
1629        // the final element. Then divide by the stride size, to get one less than the
1630        // length, and then add one. This works even if the buffer size does include the
1631        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1632        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1633        // rules, together with draw-time validation when `minBindingSize` is zero,
1634        // prevent that.
1635        write!(
1636            self.out,
1637            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1638            member = ArraySizeMember(handle),
1639            offset = offset,
1640            size = size,
1641            stride = stride,
1642        )?;
1643        Ok(())
1644    }
1645
1646    /// Emit code for the arithmetic expression of the dot product.
1647    ///
1648    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1649    /// an index. It writes out the expression for the vector component at that index.
1650    fn put_dot_product<T: Copy>(
1651        &mut self,
1652        arg: T,
1653        arg1: T,
1654        size: usize,
1655        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1656    ) -> BackendResult {
1657        // Write parentheses around the dot product expression to prevent operators
1658        // with different precedences from applying earlier.
1659        write!(self.out, "(")?;
1660
1661        // Cycle through all the components of the vector
1662        for index in 0..size {
1663            // Write the addition to the previous product
1664            // This will print an extra '+' at the beginning but that is fine in msl
1665            write!(self.out, " + ")?;
1666            extractor(self, arg, index)?;
1667            write!(self.out, " * ")?;
1668            extractor(self, arg1, index)?;
1669        }
1670
1671        write!(self.out, ")")?;
1672        Ok(())
1673    }
1674
1675    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1676    fn put_pack4x8(
1677        &mut self,
1678        arg: Handle<crate::Expression>,
1679        context: &ExpressionContext<'_>,
1680        was_signed: bool,
1681        clamp_bounds: Option<(&str, &str)>,
1682    ) -> Result<(), Error> {
1683        let write_arg = |this: &mut Self| -> BackendResult {
1684            if let Some((min, max)) = clamp_bounds {
1685                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1686                write!(this.out, "{NAMESPACE}::clamp(")?;
1687                this.put_expression(arg, context, true)?;
1688                write!(this.out, ", {min}, {max})")?;
1689            } else {
1690                this.put_expression(arg, context, true)?;
1691            }
1692            Ok(())
1693        };
1694
1695        if context.lang_version >= (2, 1) {
1696            let packed_type = if was_signed {
1697                "packed_char4"
1698            } else {
1699                "packed_uchar4"
1700            };
1701            // Metal uses little endian byte order, which matches what WGSL expects here.
1702            write!(self.out, "as_type<uint>({packed_type}(")?;
1703            write_arg(self)?;
1704            write!(self.out, "))")?;
1705        } else {
1706            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1707            if was_signed {
1708                write!(self.out, "uint(")?;
1709            }
1710            write!(self.out, "(")?;
1711            write_arg(self)?;
1712            write!(self.out, "[0] & 0xFF) | ((")?;
1713            write_arg(self)?;
1714            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1715            write_arg(self)?;
1716            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1717            write_arg(self)?;
1718            write!(self.out, "[3] & 0xFF) << 24)")?;
1719            if was_signed {
1720                write!(self.out, ")")?;
1721            }
1722        }
1723
1724        Ok(())
1725    }
1726
1727    /// Emit code for the isign expression.
1728    ///
1729    fn put_isign(
1730        &mut self,
1731        arg: Handle<crate::Expression>,
1732        context: &ExpressionContext,
1733    ) -> BackendResult {
1734        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1735        let scalar = context
1736            .resolve_type(arg)
1737            .scalar()
1738            .expect("put_isign should only be called for args which have an integer scalar type")
1739            .to_msl_name();
1740        match context.resolve_type(arg) {
1741            &crate::TypeInner::Vector { size, .. } => {
1742                let size = common::vector_size_str(size);
1743                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1744            }
1745            _ => {
1746                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1747            }
1748        }
1749        write!(self.out, ", (")?;
1750        self.put_expression(arg, context, true)?;
1751        write!(self.out, " > 0)), {scalar}(0), (")?;
1752        self.put_expression(arg, context, true)?;
1753        write!(self.out, " == 0))")?;
1754        Ok(())
1755    }
1756
1757    pub(super) fn put_const_expression(
1758        &mut self,
1759        expr_handle: Handle<crate::Expression>,
1760        module: &crate::Module,
1761        mod_info: &valid::ModuleInfo,
1762        arena: &crate::Arena<crate::Expression>,
1763    ) -> BackendResult {
1764        self.put_possibly_const_expression(
1765            expr_handle,
1766            arena,
1767            module,
1768            mod_info,
1769            &(module, mod_info),
1770            |&(_, mod_info), expr| &mod_info[expr],
1771            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1772        )
1773    }
1774
1775    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1776        match literal {
1777            crate::Literal::F64(_) => {
1778                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1779            }
1780            crate::Literal::F16(value) => {
1781                if value.is_infinite() {
1782                    let sign = if value.is_sign_negative() { "-" } else { "" };
1783                    write!(self.out, "{sign}INFINITY")?;
1784                } else if value.is_nan() {
1785                    write!(self.out, "NAN")?;
1786                } else {
1787                    let suffix = if value.fract() == f16::from_f32(0.0) {
1788                        ".0h"
1789                    } else {
1790                        "h"
1791                    };
1792                    write!(self.out, "{value}{suffix}")?;
1793                }
1794            }
1795            crate::Literal::F32(value) => {
1796                if value.is_infinite() {
1797                    let sign = if value.is_sign_negative() { "-" } else { "" };
1798                    write!(self.out, "{sign}INFINITY")?;
1799                } else if value.is_nan() {
1800                    write!(self.out, "NAN")?;
1801                } else {
1802                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1803                    write!(self.out, "{value}{suffix}")?;
1804                }
1805            }
1806            crate::Literal::U16(value) => {
1807                write!(self.out, "static_cast<ushort>({value})")?;
1808            }
1809            crate::Literal::I16(value) => {
1810                write!(self.out, "static_cast<short>({value})")?;
1811            }
1812            crate::Literal::U32(value) => {
1813                write!(self.out, "{value}u")?;
1814            }
1815            crate::Literal::I32(value) => {
1816                // `-2147483648` is parsed as unary negation of positive 2147483648.
1817                // 2147483648 is too large for int32_t meaning the expression gets
1818                // promoted to a int64_t which is not our intention. Avoid this by instead
1819                // using `-2147483647 - 1`.
1820                if value == i32::MIN {
1821                    write!(self.out, "({} - 1)", value + 1)?;
1822                } else {
1823                    write!(self.out, "{value}")?;
1824                }
1825            }
1826            crate::Literal::U64(value) => {
1827                write!(self.out, "{value}uL")?;
1828            }
1829            crate::Literal::I64(value) => {
1830                // `-9223372036854775808` is parsed as unary negation of positive
1831                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1832                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1833                // value to `-9223372036854775808`. Which would then be negated, possibly
1834                // causing undefined behaviour. Avoid this by instead using
1835                // `-9223372036854775808L - 1L`.
1836                if value == i64::MIN {
1837                    write!(self.out, "({}L - 1L)", value + 1)?;
1838                } else {
1839                    write!(self.out, "{value}L")?;
1840                }
1841            }
1842            crate::Literal::Bool(value) => {
1843                write!(self.out, "{value}")?;
1844            }
1845            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1846                return Err(Error::GenericValidation(
1847                    "Unsupported abstract literal".into(),
1848                ));
1849            }
1850        }
1851        Ok(())
1852    }
1853
1854    #[allow(clippy::too_many_arguments)]
1855    fn put_possibly_const_expression<C, I, E>(
1856        &mut self,
1857        expr_handle: Handle<crate::Expression>,
1858        expressions: &crate::Arena<crate::Expression>,
1859        module: &crate::Module,
1860        mod_info: &valid::ModuleInfo,
1861        ctx: &C,
1862        get_expr_ty: I,
1863        put_expression: E,
1864    ) -> BackendResult
1865    where
1866        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1867        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1868    {
1869        match expressions[expr_handle] {
1870            crate::Expression::Literal(literal) => {
1871                self.put_literal(literal)?;
1872            }
1873            crate::Expression::Constant(handle) => {
1874                let constant = &module.constants[handle];
1875                if constant.name.is_some() {
1876                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1877                } else {
1878                    self.put_const_expression(
1879                        constant.init,
1880                        module,
1881                        mod_info,
1882                        &module.global_expressions,
1883                    )?;
1884                }
1885            }
1886            crate::Expression::ZeroValue(ty) => {
1887                let ty_name = TypeContext {
1888                    handle: ty,
1889                    gctx: module.to_ctx(),
1890                    names: &self.names,
1891                    access: crate::StorageAccess::empty(),
1892                    first_time: false,
1893                };
1894                write!(self.out, "{ty_name} {{}}")?;
1895            }
1896            crate::Expression::Compose { ty, ref components } => {
1897                let ty_name = TypeContext {
1898                    handle: ty,
1899                    gctx: module.to_ctx(),
1900                    names: &self.names,
1901                    access: crate::StorageAccess::empty(),
1902                    first_time: false,
1903                };
1904                write!(self.out, "{ty_name}")?;
1905                match module.types[ty].inner {
1906                    crate::TypeInner::Scalar(_)
1907                    | crate::TypeInner::Vector { .. }
1908                    | crate::TypeInner::Matrix { .. } => {
1909                        self.put_call_parameters_impl(
1910                            components.iter().copied(),
1911                            ctx,
1912                            put_expression,
1913                        )?;
1914                    }
1915                    crate::TypeInner::Array { .. } => {
1916                        // Naga Arrays are Metal arrays wrapped in structs, so
1917                        // we need two levels of braces.
1918                        write!(self.out, " {{{{")?;
1919                        for (index, &component) in components.iter().enumerate() {
1920                            if index != 0 {
1921                                write!(self.out, ", ")?;
1922                            }
1923                            put_expression(self, ctx, component)?;
1924                        }
1925                        write!(self.out, "}}}}")?;
1926                    }
1927                    crate::TypeInner::Struct { .. } => {
1928                        write!(self.out, " {{")?;
1929                        for (index, &component) in components.iter().enumerate() {
1930                            if index != 0 {
1931                                write!(self.out, ", ")?;
1932                            }
1933                            // insert padding initialization, if needed
1934                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1935                                write!(self.out, "{{}}, ")?;
1936                            }
1937                            put_expression(self, ctx, component)?;
1938                        }
1939                        write!(self.out, "}}")?;
1940                    }
1941                    _ => return Err(Error::UnsupportedCompose(ty)),
1942                }
1943            }
1944            crate::Expression::Splat { size, value } => {
1945                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1946                    crate::TypeInner::Scalar(scalar) => scalar,
1947                    ref ty => {
1948                        return Err(Error::GenericValidation(format!(
1949                            "Expected splat value type must be a scalar, got {ty:?}",
1950                        )))
1951                    }
1952                };
1953                put_numeric_type(&mut self.out, scalar, &[size])?;
1954                write!(self.out, "(")?;
1955                put_expression(self, ctx, value)?;
1956                write!(self.out, ")")?;
1957            }
1958            _ => {
1959                return Err(Error::Override);
1960            }
1961        }
1962
1963        Ok(())
1964    }
1965
1966    /// Emit code for the expression `expr_handle`.
1967    ///
1968    /// The `is_scoped` argument is true if the surrounding operators have the
1969    /// precedence of the comma operator, or lower. So, for example:
1970    ///
1971    /// - Pass `true` for `is_scoped` when writing function arguments, an
1972    ///   expression statement, an initializer expression, or anything already
1973    ///   wrapped in parenthesis.
1974    ///
1975    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1976    ///   almost anything else.
1977    pub(super) fn put_expression(
1978        &mut self,
1979        expr_handle: Handle<crate::Expression>,
1980        context: &ExpressionContext,
1981        is_scoped: bool,
1982    ) -> BackendResult {
1983        // Add to the set in order to track the stack size.
1984        #[cfg(test)]
1985        self.put_expression_stack_pointers
1986            .insert(ptr::from_ref(&expr_handle).cast());
1987
1988        if let Some(name) = self.named_expressions.get(&expr_handle) {
1989            write!(self.out, "{name}")?;
1990            return Ok(());
1991        }
1992
1993        let expression = &context.function.expressions[expr_handle];
1994        match *expression {
1995            crate::Expression::Literal(_)
1996            | crate::Expression::Constant(_)
1997            | crate::Expression::ZeroValue(_)
1998            | crate::Expression::Compose { .. }
1999            | crate::Expression::Splat { .. } => {
2000                self.put_possibly_const_expression(
2001                    expr_handle,
2002                    &context.function.expressions,
2003                    context.module,
2004                    context.mod_info,
2005                    context,
2006                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2007                    |writer, context, expr| writer.put_expression(expr, context, true),
2008                )?;
2009            }
2010            crate::Expression::Override(_) => return Err(Error::Override),
2011            crate::Expression::Access { base, .. }
2012            | crate::Expression::AccessIndex { base, .. } => {
2013                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
2014                // Since `put_bounds_checks` and `put_access_chain` handle an entire
2015                // access chain at a time, recursing back through `put_expression` only
2016                // for index expressions and the base object, we will never see intermediate
2017                // `Access` or `AccessIndex` expressions here.
2018                let policy = context.choose_bounds_check_policy(base);
2019                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2020                    && self.put_bounds_checks(
2021                        expr_handle,
2022                        context,
2023                        back::Level(0),
2024                        if is_scoped { "" } else { "(" },
2025                    )?
2026                {
2027                    write!(self.out, " ? ")?;
2028                    self.put_access_chain(expr_handle, policy, context)?;
2029                    write!(self.out, " : ")?;
2030
2031                    if context.resolve_type(base).pointer_space().is_some() {
2032                        // We can't just use `DefaultConstructible` if this is a pointer.
2033                        // Instead, we create a dummy local variable to serve as pointer
2034                        // target if the access is out of bounds.
2035                        let result_ty = context.info[expr_handle]
2036                            .ty
2037                            .inner_with(&context.module.types)
2038                            .pointer_base_type();
2039                        let result_ty_handle = match result_ty {
2040                            Some(TypeResolution::Handle(handle)) => handle,
2041                            Some(TypeResolution::Value(_)) => {
2042                                // As long as the result of a pointer access expression is
2043                                // passed to a function or stored in a let binding, the
2044                                // type will be in the arena. If additional uses of
2045                                // pointers become valid, this assumption might no longer
2046                                // hold. Note that the LHS of a load or store doesn't
2047                                // take this path -- there is dedicated code in `put_load`
2048                                // and `put_store`.
2049                                unreachable!(
2050                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2051                                );
2052                            }
2053                            None => {
2054                                unreachable!(
2055                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2056                                )
2057                            }
2058                        };
2059                        let name_key =
2060                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
2061                        self.out.write_str(&self.names[&name_key])?;
2062                    } else {
2063                        write!(self.out, "DefaultConstructible()")?;
2064                    }
2065
2066                    if !is_scoped {
2067                        write!(self.out, ")")?;
2068                    }
2069                } else {
2070                    self.put_access_chain(expr_handle, policy, context)?;
2071                }
2072            }
2073            crate::Expression::Swizzle {
2074                size,
2075                vector,
2076                pattern,
2077            } => {
2078                self.put_wrapped_expression_for_packed_vec3_access(
2079                    vector,
2080                    context,
2081                    false,
2082                    &Self::put_expression,
2083                )?;
2084                write!(self.out, ".")?;
2085                for &sc in pattern[..size as usize].iter() {
2086                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2087                }
2088            }
2089            crate::Expression::FunctionArgument(index) => {
2090                let name_key = match context.origin {
2091                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2092                    FunctionOrigin::EntryPoint(ep_index) => {
2093                        NameKey::EntryPointArgument(ep_index, index)
2094                    }
2095                };
2096                let name = &self.names[&name_key];
2097                write!(self.out, "{name}")?;
2098            }
2099            crate::Expression::GlobalVariable(handle) => {
2100                let name = &self.names[&NameKey::GlobalVariable(handle)];
2101                write!(self.out, "{name}")?;
2102            }
2103            crate::Expression::LocalVariable(handle) => {
2104                let name_key = NameKey::local(context.origin, handle);
2105                let name = &self.names[&name_key];
2106                write!(self.out, "{name}")?;
2107            }
2108            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2109            crate::Expression::ImageSample {
2110                coordinate,
2111                image,
2112                sampler,
2113                clamp_to_edge: true,
2114                gather: None,
2115                array_index: None,
2116                offset: None,
2117                level: crate::SampleLevel::Zero,
2118                depth_ref: None,
2119            } => {
2120                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2121                self.put_expression(image, context, true)?;
2122                write!(self.out, ", ")?;
2123                self.put_expression(sampler, context, true)?;
2124                write!(self.out, ", ")?;
2125                self.put_expression(coordinate, context, true)?;
2126                write!(self.out, ")")?;
2127            }
2128            crate::Expression::ImageSample {
2129                image,
2130                sampler,
2131                gather,
2132                coordinate,
2133                array_index,
2134                offset,
2135                level,
2136                depth_ref,
2137                clamp_to_edge,
2138            } => {
2139                if clamp_to_edge {
2140                    return Err(Error::GenericValidation(
2141                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2142                    ));
2143                }
2144
2145                let main_op = match gather {
2146                    Some(_) => "gather",
2147                    None => "sample",
2148                };
2149                let comparison_op = match depth_ref {
2150                    Some(_) => "_compare",
2151                    None => "",
2152                };
2153                self.put_expression(image, context, false)?;
2154                write!(self.out, ".{main_op}{comparison_op}(")?;
2155                self.put_expression(sampler, context, true)?;
2156                write!(self.out, ", ")?;
2157                self.put_expression(coordinate, context, true)?;
2158                if let Some(expr) = array_index {
2159                    write!(self.out, ", ")?;
2160                    self.put_expression(expr, context, true)?;
2161                }
2162                if let Some(dref) = depth_ref {
2163                    write!(self.out, ", ")?;
2164                    self.put_expression(dref, context, true)?;
2165                }
2166
2167                self.put_image_sample_level(image, level, context)?;
2168
2169                if let Some(offset) = offset {
2170                    write!(self.out, ", ")?;
2171                    self.put_expression(offset, context, true)?;
2172                }
2173
2174                match gather {
2175                    None | Some(crate::SwizzleComponent::X) => {}
2176                    Some(component) => {
2177                        let is_cube_map = match *context.resolve_type(image) {
2178                            crate::TypeInner::Image {
2179                                dim: crate::ImageDimension::Cube,
2180                                ..
2181                            } => true,
2182                            _ => false,
2183                        };
2184                        // Offset always comes before the gather, except
2185                        // in cube maps where it's not applicable
2186                        if offset.is_none() && !is_cube_map {
2187                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2188                        }
2189                        let letter = back::COMPONENTS[component as usize];
2190                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2191                    }
2192                }
2193                write!(self.out, ")")?;
2194            }
2195            crate::Expression::ImageLoad {
2196                image,
2197                coordinate,
2198                array_index,
2199                sample,
2200                level,
2201            } => {
2202                let address = TexelAddress {
2203                    coordinate,
2204                    array_index,
2205                    sample,
2206                    level: level.map(LevelOfDetail::Direct),
2207                };
2208                self.put_image_load(expr_handle, image, address, context)?;
2209            }
2210            //Note: for all the queries, the signed integers are expected,
2211            // so a conversion is needed.
2212            crate::Expression::ImageQuery { image, query } => match query {
2213                crate::ImageQuery::Size { level } => {
2214                    self.put_image_size_query(
2215                        image,
2216                        level.map(LevelOfDetail::Direct),
2217                        crate::ScalarKind::Uint,
2218                        context,
2219                    )?;
2220                }
2221                crate::ImageQuery::NumLevels => {
2222                    self.put_expression(image, context, false)?;
2223                    write!(self.out, ".get_num_mip_levels()")?;
2224                }
2225                crate::ImageQuery::NumLayers => {
2226                    self.put_expression(image, context, false)?;
2227                    write!(self.out, ".get_array_size()")?;
2228                }
2229                crate::ImageQuery::NumSamples => {
2230                    self.put_expression(image, context, false)?;
2231                    write!(self.out, ".get_num_samples()")?;
2232                }
2233            },
2234            crate::Expression::Unary { op, expr } => {
2235                let op_str = match op {
2236                    crate::UnaryOperator::Negate => {
2237                        match context.resolve_type(expr).scalar_kind() {
2238                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2239                            _ => "-",
2240                        }
2241                    }
2242                    crate::UnaryOperator::LogicalNot => "!",
2243                    crate::UnaryOperator::BitwiseNot => "~",
2244                };
2245                write!(self.out, "{op_str}(")?;
2246                self.put_expression(expr, context, false)?;
2247                write!(self.out, ")")?;
2248            }
2249            crate::Expression::Binary { op, left, right } => {
2250                let kind = context
2251                    .resolve_type(left)
2252                    .scalar_kind()
2253                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2254
2255                if op == crate::BinaryOperator::Divide
2256                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2257                {
2258                    write!(self.out, "{DIV_FUNCTION}(")?;
2259                    self.put_expression(left, context, true)?;
2260                    write!(self.out, ", ")?;
2261                    self.put_expression(right, context, true)?;
2262                    write!(self.out, ")")?;
2263                } else if op == crate::BinaryOperator::Modulo
2264                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2265                {
2266                    write!(self.out, "{MOD_FUNCTION}(")?;
2267                    self.put_expression(left, context, true)?;
2268                    write!(self.out, ", ")?;
2269                    self.put_expression(right, context, true)?;
2270                    write!(self.out, ")")?;
2271                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2272                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2273                    //
2274                    // float:
2275                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2276                    write!(self.out, "{NAMESPACE}::fmod(")?;
2277                    self.put_expression(left, context, true)?;
2278                    write!(self.out, ", ")?;
2279                    self.put_expression(right, context, true)?;
2280                    write!(self.out, ")")?;
2281                } else if (op == crate::BinaryOperator::Add
2282                    || op == crate::BinaryOperator::Subtract
2283                    || op == crate::BinaryOperator::Multiply)
2284                    && kind == crate::ScalarKind::Sint
2285                {
2286                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2287                        crate::TypeInner::Scalar(scalar) => {
2288                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2289                                kind: crate::ScalarKind::Uint,
2290                                ..scalar
2291                            }))
2292                        }
2293                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2294                            size,
2295                            scalar: crate::Scalar {
2296                                kind: crate::ScalarKind::Uint,
2297                                ..scalar
2298                            },
2299                        }),
2300                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2301                    };
2302
2303                    // Avoid undefined behaviour due to overflowing signed
2304                    // integer arithmetic. Cast the operands to unsigned prior
2305                    // to performing the operation, then cast the result back
2306                    // to signed.
2307                    self.put_bitcasted_expression(
2308                        context.resolve_type(expr_handle),
2309                        expr_handle,
2310                        context,
2311                        &|writer, context, is_scoped| {
2312                            writer.put_binop(
2313                                op,
2314                                left,
2315                                right,
2316                                context,
2317                                is_scoped,
2318                                &|writer, expr, context, _is_scoped| {
2319                                    writer.put_bitcasted_expression(
2320                                        &to_unsigned(context.resolve_type(expr))?,
2321                                        expr,
2322                                        context,
2323                                        &|writer, context, is_scoped| {
2324                                            writer.put_expression(expr, context, is_scoped)
2325                                        },
2326                                    )
2327                                },
2328                            )
2329                        },
2330                    )?;
2331                } else {
2332                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2333                }
2334            }
2335            crate::Expression::Select {
2336                condition,
2337                accept,
2338                reject,
2339            } => match *context.resolve_type(condition) {
2340                crate::TypeInner::Scalar(crate::Scalar {
2341                    kind: crate::ScalarKind::Bool,
2342                    ..
2343                }) => {
2344                    if !is_scoped {
2345                        write!(self.out, "(")?;
2346                    }
2347                    self.put_expression(condition, context, false)?;
2348                    write!(self.out, " ? ")?;
2349                    self.put_expression(accept, context, false)?;
2350                    write!(self.out, " : ")?;
2351                    self.put_expression(reject, context, false)?;
2352                    if !is_scoped {
2353                        write!(self.out, ")")?;
2354                    }
2355                }
2356                crate::TypeInner::Vector {
2357                    scalar:
2358                        crate::Scalar {
2359                            kind: crate::ScalarKind::Bool,
2360                            ..
2361                        },
2362                    ..
2363                } => {
2364                    write!(self.out, "{NAMESPACE}::select(")?;
2365                    self.put_expression(reject, context, true)?;
2366                    write!(self.out, ", ")?;
2367                    self.put_expression(accept, context, true)?;
2368                    write!(self.out, ", ")?;
2369                    self.put_expression(condition, context, true)?;
2370                    write!(self.out, ")")?;
2371                }
2372                ref ty => {
2373                    return Err(Error::GenericValidation(format!(
2374                        "Expected select condition to be a non-bool type, got {ty:?}",
2375                    )))
2376                }
2377            },
2378            crate::Expression::Derivative { axis, expr, .. } => {
2379                use crate::DerivativeAxis as Axis;
2380                let op = match axis {
2381                    Axis::X => "dfdx",
2382                    Axis::Y => "dfdy",
2383                    Axis::Width => "fwidth",
2384                };
2385                write!(self.out, "{NAMESPACE}::{op}")?;
2386                self.put_call_parameters(iter::once(expr), context)?;
2387            }
2388            crate::Expression::Relational { fun, argument } => {
2389                let op = match fun {
2390                    crate::RelationalFunction::Any => "any",
2391                    crate::RelationalFunction::All => "all",
2392                    crate::RelationalFunction::IsNan => "isnan",
2393                    crate::RelationalFunction::IsInf => "isinf",
2394                };
2395                write!(self.out, "{NAMESPACE}::{op}")?;
2396                self.put_call_parameters(iter::once(argument), context)?;
2397            }
2398            crate::Expression::Math {
2399                fun,
2400                arg,
2401                arg1,
2402                arg2,
2403                arg3,
2404            } => {
2405                use crate::MathFunction as Mf;
2406
2407                let arg_type = context.resolve_type(arg);
2408                let scalar_argument = match arg_type {
2409                    &crate::TypeInner::Scalar(_) => true,
2410                    _ => false,
2411                };
2412
2413                let fun_name = match fun {
2414                    // comparison
2415                    Mf::Abs => "abs",
2416                    Mf::Min => "min",
2417                    Mf::Max => "max",
2418                    Mf::Clamp => "clamp",
2419                    Mf::Saturate => "saturate",
2420                    // trigonometry
2421                    Mf::Cos => "cos",
2422                    Mf::Cosh => "cosh",
2423                    Mf::Sin => "sin",
2424                    Mf::Sinh => "sinh",
2425                    Mf::Tan => "tan",
2426                    Mf::Tanh => "tanh",
2427                    Mf::Acos => "acos",
2428                    Mf::Asin => "asin",
2429                    Mf::Atan => "atan",
2430                    Mf::Atan2 => "atan2",
2431                    Mf::Asinh => "asinh",
2432                    Mf::Acosh => "acosh",
2433                    Mf::Atanh => "atanh",
2434                    Mf::Radians => "",
2435                    Mf::Degrees => "",
2436                    // decomposition
2437                    Mf::Ceil => "ceil",
2438                    Mf::Floor => "floor",
2439                    Mf::Round => "rint",
2440                    Mf::Fract => "fract",
2441                    Mf::Trunc => "trunc",
2442                    Mf::Modf => MODF_FUNCTION,
2443                    Mf::Frexp => FREXP_FUNCTION,
2444                    Mf::Ldexp => "ldexp",
2445                    // exponent
2446                    Mf::Exp => "exp",
2447                    Mf::Exp2 => "exp2",
2448                    Mf::Log => "log",
2449                    Mf::Log2 => "log2",
2450                    Mf::Pow => "pow",
2451                    // geometry
2452                    Mf::Dot => match *context.resolve_type(arg) {
2453                        crate::TypeInner::Vector {
2454                            scalar:
2455                                crate::Scalar {
2456                                    // Resolve float values to MSL's builtin dot function.
2457                                    kind: crate::ScalarKind::Float,
2458                                    ..
2459                                },
2460                            ..
2461                        } => "dot",
2462                        crate::TypeInner::Vector {
2463                            size,
2464                            scalar:
2465                                scalar @ crate::Scalar {
2466                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2467                                    ..
2468                                },
2469                        } => {
2470                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2471                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2472                            write!(self.out, "{fun_name}(")?;
2473                            self.put_expression(arg, context, true)?;
2474                            write!(self.out, ", ")?;
2475                            self.put_expression(arg1.unwrap(), context, true)?;
2476                            write!(self.out, ")")?;
2477                            return Ok(());
2478                        }
2479                        _ => unreachable!(
2480                            "Correct TypeInner for dot product should be already validated"
2481                        ),
2482                    },
2483                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2484                        if context.lang_version >= (2, 1) {
2485                            // Write potentially optimizable code using `packed_(u?)char4`.
2486                            // The two function arguments were already reinterpreted as packed (signed
2487                            // or unsigned) chars in `Self::put_block`.
2488                            let packed_type = match fun {
2489                                Mf::Dot4I8Packed => "packed_char4",
2490                                Mf::Dot4U8Packed => "packed_uchar4",
2491                                _ => unreachable!(),
2492                            };
2493
2494                            return self.put_dot_product(
2495                                Reinterpreted::new(packed_type, arg),
2496                                Reinterpreted::new(packed_type, arg1.unwrap()),
2497                                4,
2498                                |writer, arg, index| {
2499                                    // MSL implicitly promotes these (signed or unsigned) chars to
2500                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2501                                    write!(writer.out, "{arg}[{index}]")?;
2502                                    Ok(())
2503                                },
2504                            );
2505                        } else {
2506                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2507                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2508                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2509                            let conversion = match fun {
2510                                Mf::Dot4I8Packed => "int",
2511                                Mf::Dot4U8Packed => "",
2512                                _ => unreachable!(),
2513                            };
2514
2515                            return self.put_dot_product(
2516                                arg,
2517                                arg1.unwrap(),
2518                                4,
2519                                |writer, arg, index| {
2520                                    write!(writer.out, "({conversion}(")?;
2521                                    writer.put_expression(arg, context, true)?;
2522                                    if index == 3 {
2523                                        write!(writer.out, ") >> 24)")?;
2524                                    } else {
2525                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2526                                    }
2527                                    Ok(())
2528                                },
2529                            );
2530                        }
2531                    }
2532                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2533                    Mf::Cross => "cross",
2534                    Mf::Distance => "distance",
2535                    Mf::Length if scalar_argument => "abs",
2536                    Mf::Length => "length",
2537                    Mf::Normalize => "normalize",
2538                    Mf::FaceForward => "faceforward",
2539                    Mf::Reflect => "reflect",
2540                    Mf::Refract => "refract",
2541                    // computational
2542                    Mf::Sign => match arg_type.scalar_kind() {
2543                        Some(crate::ScalarKind::Sint) => {
2544                            return self.put_isign(arg, context);
2545                        }
2546                        _ => "sign",
2547                    },
2548                    Mf::Fma => "fma",
2549                    Mf::Mix => "mix",
2550                    Mf::Step => "step",
2551                    Mf::SmoothStep => "smoothstep",
2552                    Mf::Sqrt => "sqrt",
2553                    Mf::InverseSqrt => "rsqrt",
2554                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2555                    Mf::Transpose => "transpose",
2556                    Mf::Determinant => "determinant",
2557                    Mf::QuantizeToF16 => "",
2558                    // bits
2559                    Mf::CountTrailingZeros => "ctz",
2560                    Mf::CountLeadingZeros => "clz",
2561                    Mf::CountOneBits => "popcount",
2562                    Mf::ReverseBits => "reverse_bits",
2563                    Mf::ExtractBits => "",
2564                    Mf::InsertBits => "",
2565                    Mf::FirstTrailingBit => "",
2566                    Mf::FirstLeadingBit => "",
2567                    // data packing
2568                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2569                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2570                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2571                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2572                    Mf::Pack2x16float => "",
2573                    Mf::Pack4xI8 => "",
2574                    Mf::Pack4xU8 => "",
2575                    Mf::Pack4xI8Clamp => "",
2576                    Mf::Pack4xU8Clamp => "",
2577                    // data unpacking
2578                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2579                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2580                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2581                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2582                    Mf::Unpack2x16float => "",
2583                    Mf::Unpack4xI8 => "",
2584                    Mf::Unpack4xU8 => "",
2585                };
2586
2587                match fun {
2588                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2589                        // reverse_bits is listed as requiring MSL 2.1 but that
2590                        // is a copy/paste error. Looking at previous snapshots
2591                        // on web.archive.org it's present in MSL 1.2.
2592                        //
2593                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2594                        // also talks about MSL 1.2 adding "New integer
2595                        // functions to extract, insert, and reverse bits, as
2596                        // described in Integer Functions."
2597                        if context.lang_version < (1, 2) {
2598                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2599                        }
2600                    }
2601                    _ => {}
2602                }
2603
2604                match fun {
2605                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2606                        write!(self.out, "{ABS_FUNCTION}(")?;
2607                        self.put_expression(arg, context, true)?;
2608                        write!(self.out, ")")?;
2609                    }
2610                    Mf::Distance if scalar_argument => {
2611                        write!(self.out, "{NAMESPACE}::abs(")?;
2612                        self.put_expression(arg, context, false)?;
2613                        write!(self.out, " - ")?;
2614                        self.put_expression(arg1.unwrap(), context, false)?;
2615                        write!(self.out, ")")?;
2616                    }
2617                    Mf::FirstTrailingBit => {
2618                        let scalar = context.resolve_type(arg).scalar().unwrap();
2619                        let constant = scalar.width * 8 + 1;
2620
2621                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2622                        self.put_expression(arg, context, true)?;
2623                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2624                    }
2625                    Mf::FirstLeadingBit => {
2626                        let inner = context.resolve_type(arg);
2627                        let scalar = inner.scalar().unwrap();
2628                        let constant = scalar.width * 8 - 1;
2629
2630                        write!(
2631                            self.out,
2632                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2633                        )?;
2634
2635                        if scalar.kind == crate::ScalarKind::Sint {
2636                            write!(self.out, "{NAMESPACE}::select(")?;
2637                            self.put_expression(arg, context, true)?;
2638                            write!(self.out, ", ~")?;
2639                            self.put_expression(arg, context, true)?;
2640                            write!(self.out, ", ")?;
2641                            self.put_expression(arg, context, true)?;
2642                            write!(self.out, " < 0)")?;
2643                        } else {
2644                            self.put_expression(arg, context, true)?;
2645                        }
2646
2647                        write!(self.out, "), ")?;
2648
2649                        // or metal will complain that select is ambiguous
2650                        match *inner {
2651                            crate::TypeInner::Vector { size, scalar } => {
2652                                let size = common::vector_size_str(size);
2653                                let name = scalar.to_msl_name();
2654                                write!(self.out, "{name}{size}")?;
2655                            }
2656                            crate::TypeInner::Scalar(scalar) => {
2657                                let name = scalar.to_msl_name();
2658                                write!(self.out, "{name}")?;
2659                            }
2660                            _ => (),
2661                        }
2662
2663                        write!(self.out, "(-1), ")?;
2664                        self.put_expression(arg, context, true)?;
2665                        write!(self.out, " == 0")?;
2666                        if scalar.kind == crate::ScalarKind::Sint {
2667                            write!(self.out, " || ")?;
2668                            self.put_expression(arg, context, true)?;
2669                            write!(self.out, " == -1")?;
2670                        }
2671                        write!(self.out, ")")?;
2672                    }
2673                    Mf::Unpack2x16float => {
2674                        write!(self.out, "float2(as_type<half2>(")?;
2675                        self.put_expression(arg, context, false)?;
2676                        write!(self.out, "))")?;
2677                    }
2678                    Mf::Pack2x16float => {
2679                        write!(self.out, "as_type<uint>(half2(")?;
2680                        self.put_expression(arg, context, false)?;
2681                        write!(self.out, "))")?;
2682                    }
2683                    Mf::ExtractBits => {
2684                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2685                        // to sanitize the offset and count first. If we don't do this, Apple chips
2686                        // will return out-of-spec values if the extracted range is not within the bit width.
2687                        //
2688                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2689                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2690                        //
2691                        // w = sizeof(x) * 8
2692                        // o = min(offset, w)
2693                        // tmp = w - o
2694                        // c = min(count, tmp)
2695                        //
2696                        // bitfieldExtract(x, o, c)
2697                        //
2698                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))
2699
2700                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2701
2702                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2703                        self.put_expression(arg, context, true)?;
2704                        write!(self.out, ", {NAMESPACE}::min(")?;
2705                        self.put_expression(arg1.unwrap(), context, true)?;
2706                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2707                        self.put_expression(arg2.unwrap(), context, true)?;
2708                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2709                        self.put_expression(arg1.unwrap(), context, true)?;
2710                        write!(self.out, ", {scalar_bits}u)))")?;
2711                    }
2712                    Mf::InsertBits => {
2713                        // The behavior of InsertBits has the same issue as ExtractBits.
2714                        //
2715                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))
2716
2717                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2718
2719                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2720                        self.put_expression(arg, context, true)?;
2721                        write!(self.out, ", ")?;
2722                        self.put_expression(arg1.unwrap(), context, true)?;
2723                        write!(self.out, ", {NAMESPACE}::min(")?;
2724                        self.put_expression(arg2.unwrap(), context, true)?;
2725                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2726                        self.put_expression(arg3.unwrap(), context, true)?;
2727                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2728                        self.put_expression(arg2.unwrap(), context, true)?;
2729                        write!(self.out, ", {scalar_bits}u)))")?;
2730                    }
2731                    Mf::Radians => {
2732                        write!(self.out, "((")?;
2733                        self.put_expression(arg, context, false)?;
2734                        write!(self.out, ") * 0.017453292519943295474)")?;
2735                    }
2736                    Mf::Degrees => {
2737                        write!(self.out, "((")?;
2738                        self.put_expression(arg, context, false)?;
2739                        write!(self.out, ") * 57.295779513082322865)")?;
2740                    }
2741                    Mf::Modf | Mf::Frexp => {
2742                        write!(self.out, "{fun_name}")?;
2743                        self.put_call_parameters(iter::once(arg), context)?;
2744                    }
2745                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2746                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2747                    Mf::Pack4xI8Clamp => {
2748                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2749                    }
2750                    Mf::Pack4xU8Clamp => {
2751                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2752                    }
2753                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2754                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2755                            "u"
2756                        } else {
2757                            ""
2758                        };
2759
2760                        if context.lang_version >= (2, 1) {
2761                            // Metal uses little endian byte order, which matches what WGSL expects here.
2762                            write!(
2763                                self.out,
2764                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2765                            )?;
2766                            self.put_expression(arg, context, true)?;
2767                            write!(self.out, "))")?;
2768                        } else {
2769                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2770                            write!(self.out, "({sign_prefix}int4(")?;
2771                            self.put_expression(arg, context, true)?;
2772                            write!(self.out, ", ")?;
2773                            self.put_expression(arg, context, true)?;
2774                            write!(self.out, " >> 8, ")?;
2775                            self.put_expression(arg, context, true)?;
2776                            write!(self.out, " >> 16, ")?;
2777                            self.put_expression(arg, context, true)?;
2778                            write!(self.out, " >> 24) << 24 >> 24)")?;
2779                        }
2780                    }
2781                    Mf::QuantizeToF16 => {
2782                        match *context.resolve_type(arg) {
2783                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2784                            crate::TypeInner::Vector { size, .. } => write!(
2785                                self.out,
2786                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2787                                size = common::vector_size_str(size),
2788                            )?,
2789                            _ => unreachable!(
2790                                "Correct TypeInner for QuantizeToF16 should be already validated"
2791                            ),
2792                        };
2793
2794                        self.put_expression(arg, context, true)?;
2795                        write!(self.out, "))")?;
2796                    }
2797                    _ => {
2798                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2799                        self.put_call_parameters(
2800                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2801                            context,
2802                        )?;
2803                    }
2804                }
2805            }
2806            crate::Expression::As {
2807                expr,
2808                kind,
2809                convert,
2810            } => match *context.resolve_type(expr) {
2811                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2812                    if src.kind == crate::ScalarKind::Float
2813                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2814                        && convert.is_some()
2815                    {
2816                        // Use helper functions for float to int casts in order to avoid
2817                        // undefined behaviour when value is out of range for the target
2818                        // type.
2819                        let fun_name = match (kind, convert) {
2820                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2821                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2822                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2823                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2824                            _ => unreachable!(),
2825                        };
2826                        write!(self.out, "{fun_name}(")?;
2827                        self.put_expression(expr, context, true)?;
2828                        write!(self.out, ")")?;
2829                    } else {
2830                        let target_scalar = crate::Scalar {
2831                            kind,
2832                            width: convert.unwrap_or(src.width),
2833                        };
2834                        let op = match convert {
2835                            Some(_) => "static_cast",
2836                            None => "as_type",
2837                        };
2838                        write!(self.out, "{op}<")?;
2839                        match *context.resolve_type(expr) {
2840                            crate::TypeInner::Vector { size, .. } => {
2841                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2842                            }
2843                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2844                        };
2845                        write!(self.out, ">(")?;
2846                        self.put_expression(expr, context, true)?;
2847                        write!(self.out, ")")?;
2848                    }
2849                }
2850                crate::TypeInner::Matrix {
2851                    columns,
2852                    rows,
2853                    scalar,
2854                } => {
2855                    let target_scalar = crate::Scalar {
2856                        kind,
2857                        width: convert.unwrap_or(scalar.width),
2858                    };
2859                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2860                    write!(self.out, "(")?;
2861                    self.put_expression(expr, context, true)?;
2862                    write!(self.out, ")")?;
2863                }
2864                ref ty => {
2865                    return Err(Error::GenericValidation(format!(
2866                        "Unsupported type for As: {ty:?}"
2867                    )))
2868                }
2869            },
2870            // has to be a named expression
2871            crate::Expression::CallResult(_)
2872            | crate::Expression::AtomicResult { .. }
2873            | crate::Expression::WorkGroupUniformLoadResult { .. }
2874            | crate::Expression::SubgroupBallotResult
2875            | crate::Expression::SubgroupOperationResult { .. }
2876            | crate::Expression::RayQueryProceedResult => {
2877                unreachable!()
2878            }
2879            crate::Expression::ArrayLength(expr) => {
2880                // Find the global to which the array belongs.
2881                let global = match context.function.expressions[expr] {
2882                    crate::Expression::AccessIndex { base, .. } => {
2883                        match context.function.expressions[base] {
2884                            crate::Expression::GlobalVariable(handle) => handle,
2885                            ref ex => {
2886                                return Err(Error::GenericValidation(format!(
2887                                    "Expected global variable in AccessIndex, got {ex:?}"
2888                                )))
2889                            }
2890                        }
2891                    }
2892                    crate::Expression::GlobalVariable(handle) => handle,
2893                    ref ex => {
2894                        return Err(Error::GenericValidation(format!(
2895                            "Unexpected expression in ArrayLength, got {ex:?}"
2896                        )))
2897                    }
2898                };
2899
2900                if !is_scoped {
2901                    write!(self.out, "(")?;
2902                }
2903                write!(self.out, "1 + ")?;
2904                self.put_dynamic_array_max_index(global, context)?;
2905                if !is_scoped {
2906                    write!(self.out, ")")?;
2907                }
2908            }
2909            crate::Expression::RayQueryVertexPositions { .. } => {
2910                unimplemented!()
2911            }
2912            crate::Expression::RayQueryGetIntersection { query, committed } => {
2913                if context.lang_version < (2, 4) {
2914                    return Err(Error::UnsupportedRayTracing);
2915                }
2916
2917                write!(
2918                    self.out,
2919                    "{}_{committed}(",
2920                    super::ray::INTERSECTION_FUNCTION_NAME
2921                )?;
2922                self.put_expression(query, context, true)?;
2923                write!(self.out, ")")?;
2924            }
2925            crate::Expression::CooperativeLoad { ref data, .. } => {
2926                if context.lang_version < (2, 3) {
2927                    return Err(Error::UnsupportedCooperativeMatrix);
2928                }
2929                write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2930                write!(self.out, "&")?;
2931                self.put_access_chain(data.pointer, context.policies.index, context)?;
2932                write!(self.out, ", ")?;
2933                self.put_expression(data.stride, context, true)?;
2934                write!(self.out, ", {})", data.row_major)?;
2935            }
2936            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2937                if context.lang_version < (2, 3) {
2938                    return Err(Error::UnsupportedCooperativeMatrix);
2939                }
2940                write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2941                self.put_expression(a, context, true)?;
2942                write!(self.out, ", ")?;
2943                self.put_expression(b, context, true)?;
2944                write!(self.out, ", ")?;
2945                self.put_expression(c, context, true)?;
2946                write!(self.out, ")")?;
2947            }
2948        }
2949        Ok(())
2950    }
2951
2952    /// Emits code for a binary operation, using the provided callback to emit
2953    /// the left and right operands.
2954    fn put_binop<F>(
2955        &mut self,
2956        op: crate::BinaryOperator,
2957        left: Handle<crate::Expression>,
2958        right: Handle<crate::Expression>,
2959        context: &ExpressionContext,
2960        is_scoped: bool,
2961        put_expression: &F,
2962    ) -> BackendResult
2963    where
2964        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2965    {
2966        let op_str = back::binary_operation_str(op);
2967
2968        if !is_scoped {
2969            write!(self.out, "(")?;
2970        }
2971
2972        // Cast packed vector if necessary
2973        // Packed vector - matrix multiplications are not supported in MSL
2974        if op == crate::BinaryOperator::Multiply
2975            && matches!(
2976                context.resolve_type(right),
2977                &crate::TypeInner::Matrix { .. }
2978            )
2979        {
2980            self.put_wrapped_expression_for_packed_vec3_access(
2981                left,
2982                context,
2983                false,
2984                put_expression,
2985            )?;
2986        } else {
2987            put_expression(self, left, context, false)?;
2988        }
2989
2990        write!(self.out, " {op_str} ")?;
2991
2992        // See comment above
2993        if op == crate::BinaryOperator::Multiply
2994            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2995        {
2996            self.put_wrapped_expression_for_packed_vec3_access(
2997                right,
2998                context,
2999                false,
3000                put_expression,
3001            )?;
3002        } else {
3003            put_expression(self, right, context, false)?;
3004        }
3005
3006        if !is_scoped {
3007            write!(self.out, ")")?;
3008        }
3009
3010        Ok(())
3011    }
3012
3013    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
3014    fn put_wrapped_expression_for_packed_vec3_access<F>(
3015        &mut self,
3016        expr_handle: Handle<crate::Expression>,
3017        context: &ExpressionContext,
3018        is_scoped: bool,
3019        put_expression: &F,
3020    ) -> BackendResult
3021    where
3022        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3023    {
3024        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3025            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3026            put_expression(self, expr_handle, context, is_scoped)?;
3027            write!(self.out, ")")?;
3028        } else {
3029            put_expression(self, expr_handle, context, is_scoped)?;
3030        }
3031        Ok(())
3032    }
3033
3034    /// Emits code for an expression using the provided callback, wrapping the
3035    /// result in a bitcast to the type `cast_to`.
3036    fn put_bitcasted_expression<F>(
3037        &mut self,
3038        cast_to: &crate::TypeInner,
3039        inner_expr: Handle<crate::Expression>,
3040        context: &ExpressionContext,
3041        put_expression: &F,
3042    ) -> BackendResult
3043    where
3044        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3045    {
3046        // For sub-32-bit types, C++ integer promotion can widen the inner
3047        // expression (e.g. `ushort + ushort` promotes to `int`), making a
3048        // direct `as_type<short>(int_expr)` invalid due to size mismatch.
3049        // We wrap with `static_cast` to truncate back before the bitcast.
3050        let needs_truncation = match *cast_to {
3051            crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3052            crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3053            _ => false,
3054        };
3055
3056        write!(self.out, "as_type<")?;
3057        match *cast_to {
3058            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3059            crate::TypeInner::Vector { size, scalar } => {
3060                put_numeric_type(&mut self.out, scalar, &[size])?
3061            }
3062            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3063        };
3064        write!(self.out, ">(")?;
3065
3066        if needs_truncation {
3067            write!(self.out, "static_cast<")?;
3068            // Cast to the unsigned version of the target type to truncate
3069            let unsigned_scalar = match *cast_to {
3070                crate::TypeInner::Scalar(scalar) => crate::Scalar {
3071                    kind: crate::ScalarKind::Uint,
3072                    ..scalar
3073                },
3074                crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3075                    kind: crate::ScalarKind::Uint,
3076                    ..scalar
3077                },
3078                _ => unreachable!(),
3079            };
3080            match *cast_to {
3081                crate::TypeInner::Scalar(_) => {
3082                    put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3083                }
3084                crate::TypeInner::Vector { size, .. } => {
3085                    put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3086                }
3087                _ => unreachable!(),
3088            };
3089            write!(self.out, ">(")?;
3090        }
3091
3092        // if it's packed, we must unpack it (e.g., float3(val)) before the bitcast.
3093        if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3094            put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3095            write!(self.out, "(")?;
3096            put_expression(self, context, true)?;
3097            write!(self.out, ")")?;
3098        } else {
3099            put_expression(self, context, true)?;
3100        }
3101
3102        if needs_truncation {
3103            write!(self.out, ")")?;
3104        }
3105
3106        write!(self.out, ")")?;
3107        Ok(())
3108    }
3109
3110    /// Write a `GuardedIndex` as a Metal expression.
3111    fn put_index(
3112        &mut self,
3113        index: index::GuardedIndex,
3114        context: &ExpressionContext,
3115        is_scoped: bool,
3116    ) -> BackendResult {
3117        match index {
3118            index::GuardedIndex::Expression(expr) => {
3119                self.put_expression(expr, context, is_scoped)?
3120            }
3121            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3122        }
3123        Ok(())
3124    }
3125
3126    /// Emit an index bounds check condition for `chain`, if required.
3127    ///
3128    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
3129    /// operating either on a pointer to a value, or on a value directly. If we cannot
3130    /// statically determine that all indexing operations in `chain` are within
3131    /// bounds, then write a conditional expression to check them dynamically,
3132    /// and return true. All accesses in the chain are checked by the generated
3133    /// expression.
3134    ///
3135    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
3136    ///
3137    /// The text written is of the form:
3138    ///
3139    /// ```ignore
3140    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
3141    /// ```
3142    ///
3143    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
3144    /// statements, presumably these arguments start an indented `if` statement; for
3145    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
3146    /// expression. In either case, what is written is not a complete syntactic structure
3147    /// in its own right, and the caller will have to finish it off if we return `true`.
3148    ///
3149    /// If no expression is written, return false.
3150    ///
3151    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
3152    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3153    /// [`Store`]: crate::Statement::Store
3154    /// [`Load`]: crate::Expression::Load
3155    fn put_bounds_checks(
3156        &mut self,
3157        chain: Handle<crate::Expression>,
3158        context: &ExpressionContext,
3159        level: back::Level,
3160        prefix: &'static str,
3161    ) -> Result<bool, Error> {
3162        let mut check_written = false;
3163
3164        // Iterate over the access chain, handling each required bounds check.
3165        for item in context.bounds_check_iter(chain) {
3166            let BoundsCheck {
3167                base,
3168                index,
3169                length,
3170            } = item;
3171
3172            if check_written {
3173                write!(self.out, " && ")?;
3174            } else {
3175                write!(self.out, "{level}{prefix}")?;
3176                check_written = true;
3177            }
3178
3179            // Check that the index falls within bounds. Do this with a single
3180            // comparison, by casting the index to `uint` first, so that negative
3181            // indices become large positive values.
3182            write!(self.out, "uint(")?;
3183            self.put_index(index, context, true)?;
3184            self.out.write_str(") < ")?;
3185            match length {
3186                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3187                index::IndexableLength::Dynamic => {
3188                    let global = context.function.originating_global(base).ok_or_else(|| {
3189                        Error::GenericValidation("Could not find originating global".into())
3190                    })?;
3191                    write!(self.out, "1 + ")?;
3192                    self.put_dynamic_array_max_index(global, context)?
3193                }
3194            }
3195        }
3196
3197        Ok(check_written)
3198    }
3199
3200    /// Write the access chain `chain`.
3201    ///
3202    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3203    /// operating either on a pointer to a value, or on a value directly.
3204    ///
3205    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3206    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3207    /// that must be handled in the caller.
3208    ///
3209    /// Handle the entire chain, recursing back into `put_expression` only for index
3210    /// expressions and the base expression that originates the pointer or composite value
3211    /// being accessed. This allows `put_expression` to assume that any `Access` or
3212    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3213    /// `ReadZeroSkipWrite` checks.
3214    ///
3215    /// [`Access`]: crate::Expression::Access
3216    /// [`AccessIndex`]: crate::Expression::AccessIndex
3217    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3218    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3219    fn put_access_chain(
3220        &mut self,
3221        chain: Handle<crate::Expression>,
3222        policy: index::BoundsCheckPolicy,
3223        context: &ExpressionContext,
3224    ) -> BackendResult {
3225        match context.function.expressions[chain] {
3226            crate::Expression::Access { base, index } => {
3227                let mut base_ty = context.resolve_type(base);
3228
3229                // Look through any pointers to see what we're really indexing.
3230                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3231                    base_ty = &context.module.types[base].inner;
3232                }
3233
3234                self.put_subscripted_access_chain(
3235                    base,
3236                    base_ty,
3237                    index::GuardedIndex::Expression(index),
3238                    policy,
3239                    context,
3240                )?;
3241            }
3242            crate::Expression::AccessIndex { base, index } => {
3243                let base_resolution = &context.info[base].ty;
3244                let mut base_ty = base_resolution.inner_with(&context.module.types);
3245                let mut base_ty_handle = base_resolution.handle();
3246
3247                // Look through any pointers to see what we're really indexing.
3248                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3249                    base_ty = &context.module.types[base].inner;
3250                    base_ty_handle = Some(base);
3251                }
3252
3253                // Handle structs and anything else that can use `.x` syntax here, so
3254                // `put_subscripted_access_chain` won't have to handle the absurd case of
3255                // indexing a struct with an expression.
3256                match *base_ty {
3257                    crate::TypeInner::Struct { .. } => {
3258                        let base_ty = base_ty_handle.unwrap();
3259                        self.put_access_chain(base, policy, context)?;
3260                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3261                        write!(self.out, ".{name}")?;
3262                    }
3263                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3264                        self.put_access_chain(base, policy, context)?;
3265                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3266                        // however array indexing is
3267                        if context.get_packed_vec_kind(base).is_some() {
3268                            write!(self.out, "[{index}]")?;
3269                        } else {
3270                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3271                        }
3272                    }
3273                    _ => {
3274                        self.put_subscripted_access_chain(
3275                            base,
3276                            base_ty,
3277                            index::GuardedIndex::Known(index),
3278                            policy,
3279                            context,
3280                        )?;
3281                    }
3282                }
3283            }
3284            _ => self.put_expression(chain, context, false)?,
3285        }
3286
3287        Ok(())
3288    }
3289
3290    /// Write a `[]`-style access of `base` by `index`.
3291    ///
3292    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3293    /// values within bounds.
3294    ///
3295    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3296    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3297    /// removed. Our callers often already have this handy.
3298    ///
3299    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3300    /// referencing vector components by name.
3301    ///
3302    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3303    /// [`Array`]: crate::TypeInner::Array
3304    /// [`Vector`]: crate::TypeInner::Vector
3305    /// [`Pointer`]: crate::TypeInner::Pointer
3306    fn put_subscripted_access_chain(
3307        &mut self,
3308        base: Handle<crate::Expression>,
3309        base_ty: &crate::TypeInner,
3310        index: index::GuardedIndex,
3311        policy: index::BoundsCheckPolicy,
3312        context: &ExpressionContext,
3313    ) -> BackendResult {
3314        let accessing_wrapped_array = match *base_ty {
3315            crate::TypeInner::Array {
3316                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3317                ..
3318            } => true,
3319            _ => false,
3320        };
3321        let accessing_wrapped_binding_array =
3322            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3323
3324        self.put_access_chain(base, policy, context)?;
3325        if accessing_wrapped_array {
3326            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3327        }
3328        write!(self.out, "[")?;
3329
3330        // Decide whether this index needs to be clamped to fall within range.
3331        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3332            context.access_needs_check(base, index)
3333        } else {
3334            None
3335        };
3336        if let Some(limit) = restriction_needed {
3337            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3338            self.put_index(index, context, true)?;
3339            write!(self.out, "), ")?;
3340            match limit {
3341                index::IndexableLength::Known(limit) => {
3342                    write!(self.out, "{}u", limit - 1)?;
3343                }
3344                index::IndexableLength::Dynamic => {
3345                    let global = context.function.originating_global(base).ok_or_else(|| {
3346                        Error::GenericValidation("Could not find originating global".into())
3347                    })?;
3348                    self.put_dynamic_array_max_index(global, context)?;
3349                }
3350            }
3351            write!(self.out, ")")?;
3352        } else {
3353            self.put_index(index, context, true)?;
3354        }
3355
3356        write!(self.out, "]")?;
3357
3358        if accessing_wrapped_binding_array {
3359            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3360        }
3361
3362        Ok(())
3363    }
3364
3365    fn put_load(
3366        &mut self,
3367        pointer: Handle<crate::Expression>,
3368        context: &ExpressionContext,
3369        is_scoped: bool,
3370    ) -> BackendResult {
3371        // Since access chains never cross between address spaces, we can just
3372        // check the index bounds check policy once at the top.
3373        let policy = context.choose_bounds_check_policy(pointer);
3374        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3375            && self.put_bounds_checks(
3376                pointer,
3377                context,
3378                back::Level(0),
3379                if is_scoped { "" } else { "(" },
3380            )?
3381        {
3382            write!(self.out, " ? ")?;
3383            self.put_unchecked_load(pointer, policy, context)?;
3384            write!(self.out, " : DefaultConstructible()")?;
3385
3386            if !is_scoped {
3387                write!(self.out, ")")?;
3388            }
3389        } else {
3390            self.put_unchecked_load(pointer, policy, context)?;
3391        }
3392
3393        Ok(())
3394    }
3395
3396    fn put_unchecked_load(
3397        &mut self,
3398        pointer: Handle<crate::Expression>,
3399        policy: index::BoundsCheckPolicy,
3400        context: &ExpressionContext,
3401    ) -> BackendResult {
3402        let is_atomic_pointer = context
3403            .resolve_type(pointer)
3404            .is_atomic_pointer(&context.module.types);
3405
3406        if is_atomic_pointer {
3407            write!(
3408                self.out,
3409                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3410            )?;
3411            self.put_access_chain(pointer, policy, context)?;
3412            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3413        } else {
3414            // We don't do any dereferencing with `*` here as pointer arguments to functions
3415            // are done by `&` references and not `*` pointers. These do not need to be
3416            // dereferenced.
3417            self.put_access_chain(pointer, policy, context)?;
3418        }
3419
3420        Ok(())
3421    }
3422
3423    fn put_return_value(
3424        &mut self,
3425        level: back::Level,
3426        expr_handle: Handle<crate::Expression>,
3427        result_struct: Option<&str>,
3428        context: &ExpressionContext,
3429    ) -> BackendResult {
3430        match result_struct {
3431            Some(struct_name) => {
3432                let mut has_point_size = false;
3433                let result_ty = context.function.result.as_ref().unwrap().ty;
3434                match context.module.types[result_ty].inner {
3435                    crate::TypeInner::Struct { ref members, .. } => {
3436                        let tmp = "_tmp";
3437                        write!(self.out, "{level}const auto {tmp} = ")?;
3438                        self.put_expression(expr_handle, context, true)?;
3439                        writeln!(self.out, ";")?;
3440                        write!(self.out, "{level}return {struct_name} {{")?;
3441
3442                        let mut is_first = true;
3443
3444                        for (index, member) in members.iter().enumerate() {
3445                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3446                                member.binding
3447                            {
3448                                has_point_size = true;
3449                                if !context.pipeline_options.allow_and_force_point_size {
3450                                    continue;
3451                                }
3452                            }
3453
3454                            let comma = if is_first { "" } else { "," };
3455                            is_first = false;
3456                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3457                            // HACK: we are forcefully deduplicating the expression here
3458                            // to convert from a wrapped struct to a raw array, e.g.
3459                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3460                            if let crate::TypeInner::Array {
3461                                size: crate::ArraySize::Constant(size),
3462                                ..
3463                            } = context.module.types[member.ty].inner
3464                            {
3465                                write!(self.out, "{comma} {{")?;
3466                                for j in 0..size.get() {
3467                                    if j != 0 {
3468                                        write!(self.out, ",")?;
3469                                    }
3470                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3471                                }
3472                                write!(self.out, "}}")?;
3473                            } else {
3474                                write!(self.out, "{comma} {tmp}.{name}")?;
3475                            }
3476                        }
3477                    }
3478                    _ => {
3479                        write!(self.out, "{level}return {struct_name} {{ ")?;
3480                        self.put_expression(expr_handle, context, true)?;
3481                    }
3482                }
3483
3484                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3485                    let stage = context.module.entry_points[ep_index as usize].stage;
3486                    if context.pipeline_options.allow_and_force_point_size
3487                        && stage == crate::ShaderStage::Vertex
3488                        && !has_point_size
3489                    {
3490                        // point size was injected and comes last
3491                        write!(self.out, ", 1.0")?;
3492                    }
3493                }
3494                write!(self.out, " }}")?;
3495            }
3496            None => {
3497                write!(self.out, "{level}return ")?;
3498                self.put_expression(expr_handle, context, true)?;
3499            }
3500        }
3501        writeln!(self.out, ";")?;
3502        Ok(())
3503    }
3504
3505    /// Helper method used to find which expressions of a given function require baking
3506    ///
3507    /// # Notes
3508    /// This function overwrites the contents of `self.need_bake_expressions`
3509    fn update_expressions_to_bake(
3510        &mut self,
3511        func: &crate::Function,
3512        info: &valid::FunctionInfo,
3513        context: &ExpressionContext,
3514    ) {
3515        use crate::Expression;
3516        self.need_bake_expressions.clear();
3517
3518        for (expr_handle, expr) in func.expressions.iter() {
3519            // Expressions whose reference count is above the
3520            // threshold should always be stored in temporaries.
3521            let expr_info = &info[expr_handle];
3522            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3523            if min_ref_count <= expr_info.ref_count {
3524                self.need_bake_expressions.insert(expr_handle);
3525            } else {
3526                match expr_info.ty {
3527                    // force ray desc to be baked: it's used multiple times internally
3528                    TypeResolution::Handle(h)
3529                        if Some(h) == context.module.special_types.ray_desc =>
3530                    {
3531                        self.need_bake_expressions.insert(expr_handle);
3532                    }
3533                    _ => {}
3534                }
3535            }
3536
3537            if let Expression::Math {
3538                fun,
3539                arg,
3540                arg1,
3541                arg2,
3542                ..
3543            } = *expr
3544            {
3545                match fun {
3546                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3547                    // works on floating-point vectors, so we emit inline code for
3548                    // integer vector `dot` calls. But that code uses each argument `N`
3549                    // times, once for each component (see `put_dot_product`), so to
3550                    // avoid duplicated evaluation, we must bake integer operands.
3551                    // This applies both when using the polyfill (because of the duplicate
3552                    // evaluation issue) and when we don't use the polyfill (because we
3553                    // need them to be emitted before casting to packed chars -- see the
3554                    // comment at the call to `put_casting_to_packed_chars`).
3555                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3556                        self.need_bake_expressions.insert(arg);
3557                        self.need_bake_expressions.insert(arg1.unwrap());
3558                    }
3559                    crate::MathFunction::FirstLeadingBit => {
3560                        self.need_bake_expressions.insert(arg);
3561                    }
3562                    crate::MathFunction::Pack4xI8
3563                    | crate::MathFunction::Pack4xU8
3564                    | crate::MathFunction::Pack4xI8Clamp
3565                    | crate::MathFunction::Pack4xU8Clamp
3566                    | crate::MathFunction::Unpack4xI8
3567                    | crate::MathFunction::Unpack4xU8 => {
3568                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3569                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3570                        if context.lang_version < (2, 1) {
3571                            self.need_bake_expressions.insert(arg);
3572                        }
3573                    }
3574                    crate::MathFunction::ExtractBits => {
3575                        // Only argument 1 is re-used.
3576                        self.need_bake_expressions.insert(arg1.unwrap());
3577                    }
3578                    crate::MathFunction::InsertBits => {
3579                        // Only argument 2 is re-used.
3580                        self.need_bake_expressions.insert(arg2.unwrap());
3581                    }
3582                    crate::MathFunction::Sign => {
3583                        // WGSL's `sign` function works also on signed ints, but Metal's only
3584                        // works on floating points, so we emit inline code for integer `sign`
3585                        // calls. But that code uses each argument 2 times (see `put_isign`),
3586                        // so to avoid duplicated evaluation, we must bake the argument.
3587                        let inner = context.resolve_type(expr_handle);
3588                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3589                            self.need_bake_expressions.insert(arg);
3590                        }
3591                    }
3592                    _ => {}
3593                }
3594            }
3595        }
3596    }
3597
3598    pub(super) fn start_baking_expression(
3599        &mut self,
3600        handle: Handle<crate::Expression>,
3601        context: &ExpressionContext,
3602        name: &str,
3603    ) -> BackendResult {
3604        match context.info[handle].ty {
3605            TypeResolution::Handle(ty_handle) => {
3606                let ty_name = TypeContext {
3607                    handle: ty_handle,
3608                    gctx: context.module.to_ctx(),
3609                    names: &self.names,
3610                    access: crate::StorageAccess::empty(),
3611                    first_time: false,
3612                };
3613                write!(self.out, "{ty_name}")?;
3614            }
3615            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3616                put_numeric_type(&mut self.out, scalar, &[])?;
3617            }
3618            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3619                put_numeric_type(&mut self.out, scalar, &[size])?;
3620            }
3621            TypeResolution::Value(crate::TypeInner::Matrix {
3622                columns,
3623                rows,
3624                scalar,
3625            }) => {
3626                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3627            }
3628            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3629                columns,
3630                rows,
3631                scalar,
3632                role: _,
3633            }) => {
3634                write!(
3635                    self.out,
3636                    "{}::simdgroup_{}{}x{}",
3637                    NAMESPACE,
3638                    scalar.to_msl_name(),
3639                    columns as u32,
3640                    rows as u32,
3641                )?;
3642            }
3643            TypeResolution::Value(ref other) => {
3644                log::warn!("Type {other:?} isn't a known local");
3645                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3646            }
3647        }
3648
3649        //TODO: figure out the naming scheme that wouldn't collide with user names.
3650        write!(self.out, " {name} = ")?;
3651
3652        Ok(())
3653    }
3654
3655    /// Cache a clamped level of detail value, if necessary.
3656    ///
3657    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3658    /// properly clamped level of detail value both in the access itself, and
3659    /// for fetching the size of the requested MIP level, needed to clamp the
3660    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3661    /// it in a temporary variable, as part of the [`Emit`] statement covering
3662    /// the [`ImageLoad`] expression.
3663    ///
3664    /// [`ImageLoad`]: crate::Expression::ImageLoad
3665    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3666    /// [`Emit`]: crate::Statement::Emit
3667    fn put_cache_restricted_level(
3668        &mut self,
3669        load: Handle<crate::Expression>,
3670        image: Handle<crate::Expression>,
3671        mip_level: Option<Handle<crate::Expression>>,
3672        indent: back::Level,
3673        context: &StatementContext,
3674    ) -> BackendResult {
3675        // Does this image access actually require (or even permit) a
3676        // level-of-detail, and does the policy require us to restrict it?
3677        let level_of_detail = match mip_level {
3678            Some(level) => level,
3679            None => return Ok(()),
3680        };
3681
3682        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3683            || !context.expression.image_needs_lod(image)
3684        {
3685            return Ok(());
3686        }
3687
3688        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3689        self.put_restricted_scalar_image_index(
3690            image,
3691            level_of_detail,
3692            "get_num_mip_levels",
3693            &context.expression,
3694        )?;
3695        writeln!(self.out, ";")?;
3696
3697        Ok(())
3698    }
3699
3700    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3701    ///
3702    /// Caches the results in temporary variables (whose names are derived from
3703    /// the original variable names). This caching avoids the need to redo the
3704    /// casting for each vector component when emitting the dot product.
3705    fn put_casting_to_packed_chars(
3706        &mut self,
3707        fun: crate::MathFunction,
3708        arg0: Handle<crate::Expression>,
3709        arg1: Handle<crate::Expression>,
3710        indent: back::Level,
3711        context: &StatementContext<'_>,
3712    ) -> Result<(), Error> {
3713        let packed_type = match fun {
3714            crate::MathFunction::Dot4I8Packed => "packed_char4",
3715            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3716            _ => unreachable!(),
3717        };
3718
3719        for arg in [arg0, arg1] {
3720            write!(
3721                self.out,
3722                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3723                Reinterpreted::new(packed_type, arg)
3724            )?;
3725            self.put_expression(arg, &context.expression, true)?;
3726            writeln!(self.out, ");")?;
3727        }
3728
3729        Ok(())
3730    }
3731
3732    fn put_block(
3733        &mut self,
3734        level: back::Level,
3735        statements: &[crate::Statement],
3736        context: &StatementContext,
3737    ) -> BackendResult {
3738        // Add to the set in order to track the stack size.
3739        #[cfg(test)]
3740        self.put_block_stack_pointers
3741            .insert(ptr::from_ref(&level).cast());
3742
3743        for statement in statements {
3744            log::trace!("statement[{}] {:?}", level.0, statement);
3745            match *statement {
3746                crate::Statement::Emit(ref range) => {
3747                    for handle in range.clone() {
3748                        use crate::MathFunction as Mf;
3749
3750                        match context.expression.function.expressions[handle] {
3751                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3752                            // may need to cache a clamped version of their level-of-detail argument.
3753                            crate::Expression::ImageLoad {
3754                                image,
3755                                level: mip_level,
3756                                ..
3757                            } => {
3758                                self.put_cache_restricted_level(
3759                                    handle, image, mip_level, level, context,
3760                                )?;
3761                            }
3762
3763                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3764                            // 2.1+ then we introduce two intermediate variables that recast the two
3765                            // arguments as packed (signed or unsigned) chars. The actual dot product
3766                            // is implemented in `Self::put_expression`, and it uses both of these
3767                            // intermediate variables multiple times. There's no danger that the
3768                            // original arguments get modified between the definition of these
3769                            // intermediate variables and the implementation of the actual dot
3770                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3771                            crate::Expression::Math {
3772                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3773                                arg,
3774                                arg1,
3775                                ..
3776                            } if context.expression.lang_version >= (2, 1) => {
3777                                self.put_casting_to_packed_chars(
3778                                    fun,
3779                                    arg,
3780                                    arg1.unwrap(),
3781                                    level,
3782                                    context,
3783                                )?;
3784                            }
3785
3786                            _ => (),
3787                        }
3788
3789                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3790                        let expr_name = if ptr_class.is_some() {
3791                            None // don't bake pointer expressions (just yet)
3792                        } else if let Some(name) =
3793                            context.expression.function.named_expressions.get(&handle)
3794                        {
3795                            // The `crate::Function::named_expressions` table holds
3796                            // expressions that should be saved in temporaries once they
3797                            // are `Emit`ted. We only add them to `self.named_expressions`
3798                            // when we reach the `Emit` that covers them, so that we don't
3799                            // try to use their names before we've actually initialized
3800                            // the temporary that holds them.
3801                            //
3802                            // Don't assume the names in `named_expressions` are unique,
3803                            // or even valid. Use the `Namer`.
3804                            Some(self.namer.call(name))
3805                        } else {
3806                            // If this expression is an index that we're going to first compare
3807                            // against a limit, and then actually use as an index, then we may
3808                            // want to cache it in a temporary, to avoid evaluating it twice.
3809                            let bake = if context.expression.guarded_indices.contains(handle) {
3810                                true
3811                            } else {
3812                                self.need_bake_expressions.contains(&handle)
3813                            };
3814
3815                            if bake {
3816                                Some(Baked(handle).to_string())
3817                            } else {
3818                                None
3819                            }
3820                        };
3821
3822                        if let Some(name) = expr_name {
3823                            write!(self.out, "{level}")?;
3824                            self.start_baking_expression(handle, &context.expression, &name)?;
3825                            self.put_expression(handle, &context.expression, true)?;
3826                            self.named_expressions.insert(handle, name);
3827                            writeln!(self.out, ";")?;
3828                        }
3829                    }
3830                }
3831                crate::Statement::Block(ref block) => {
3832                    if !block.is_empty() {
3833                        writeln!(self.out, "{level}{{")?;
3834                        self.put_block(level.next(), block, context)?;
3835                        writeln!(self.out, "{level}}}")?;
3836                    }
3837                }
3838                crate::Statement::If {
3839                    condition,
3840                    ref accept,
3841                    ref reject,
3842                } => {
3843                    write!(self.out, "{level}if (")?;
3844                    self.put_expression(condition, &context.expression, true)?;
3845                    writeln!(self.out, ") {{")?;
3846                    self.put_block(level.next(), accept, context)?;
3847                    if !reject.is_empty() {
3848                        writeln!(self.out, "{level}}} else {{")?;
3849                        self.put_block(level.next(), reject, context)?;
3850                    }
3851                    writeln!(self.out, "{level}}}")?;
3852                }
3853                crate::Statement::Switch {
3854                    selector,
3855                    ref cases,
3856                } => {
3857                    write!(self.out, "{level}switch(")?;
3858                    self.put_expression(selector, &context.expression, true)?;
3859                    writeln!(self.out, ") {{")?;
3860                    let lcase = level.next();
3861                    for case in cases.iter() {
3862                        match case.value {
3863                            crate::SwitchValue::I32(value) => {
3864                                write!(self.out, "{lcase}case {value}:")?;
3865                            }
3866                            crate::SwitchValue::U32(value) => {
3867                                write!(self.out, "{lcase}case {value}u:")?;
3868                            }
3869                            crate::SwitchValue::Default => {
3870                                write!(self.out, "{lcase}default:")?;
3871                            }
3872                        }
3873
3874                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3875                        if write_block_braces {
3876                            writeln!(self.out, " {{")?;
3877                        } else {
3878                            writeln!(self.out)?;
3879                        }
3880
3881                        self.put_block(lcase.next(), &case.body, context)?;
3882                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3883                        {
3884                            writeln!(self.out, "{}break;", lcase.next())?;
3885                        }
3886
3887                        if write_block_braces {
3888                            writeln!(self.out, "{lcase}}}")?;
3889                        }
3890                    }
3891                    writeln!(self.out, "{level}}}")?;
3892                }
3893                crate::Statement::Loop {
3894                    ref body,
3895                    ref continuing,
3896                    break_if,
3897                } => {
3898                    let force_loop_bound_statements =
3899                        self.gen_force_bounded_loop_statements(level, context);
3900                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3901                        .then(|| self.namer.call("loop_init"));
3902
3903                    if let Some((ref decl, _)) = force_loop_bound_statements {
3904                        writeln!(self.out, "{decl}")?;
3905                    }
3906                    if let Some(ref gate_name) = gate_name {
3907                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3908                    }
3909
3910                    writeln!(self.out, "{level}while(true) {{",)?;
3911                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3912                        writeln!(self.out, "{break_and_inc}")?;
3913                    }
3914                    if let Some(ref gate_name) = gate_name {
3915                        let lif = level.next();
3916                        let lcontinuing = lif.next();
3917                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3918                        self.put_block(lcontinuing, continuing, context)?;
3919                        if let Some(condition) = break_if {
3920                            write!(self.out, "{lcontinuing}if (")?;
3921                            self.put_expression(condition, &context.expression, true)?;
3922                            writeln!(self.out, ") {{")?;
3923                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3924                            writeln!(self.out, "{lcontinuing}}}")?;
3925                        }
3926                        writeln!(self.out, "{lif}}}")?;
3927                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3928                    }
3929                    self.put_block(level.next(), body, context)?;
3930
3931                    writeln!(self.out, "{level}}}")?;
3932                }
3933                crate::Statement::Break => {
3934                    writeln!(self.out, "{level}break;")?;
3935                }
3936                crate::Statement::Continue => {
3937                    writeln!(self.out, "{level}continue;")?;
3938                }
3939                crate::Statement::Return {
3940                    value: Some(expr_handle),
3941                } => {
3942                    self.put_return_value(
3943                        level,
3944                        expr_handle,
3945                        context.result_struct,
3946                        &context.expression,
3947                    )?;
3948                }
3949                crate::Statement::Return { value: None } => {
3950                    writeln!(self.out, "{level}return;")?;
3951                }
3952                crate::Statement::Kill => {
3953                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3954                }
3955                crate::Statement::ControlBarrier(flags)
3956                | crate::Statement::MemoryBarrier(flags) => {
3957                    self.write_barrier(flags, level)?;
3958                }
3959                crate::Statement::Store { pointer, value } => {
3960                    self.put_store(pointer, value, level, context)?
3961                }
3962                crate::Statement::ImageStore {
3963                    image,
3964                    coordinate,
3965                    array_index,
3966                    value,
3967                } => {
3968                    let address = TexelAddress {
3969                        coordinate,
3970                        array_index,
3971                        sample: None,
3972                        level: None,
3973                    };
3974                    self.put_image_store(level, image, &address, value, context)?
3975                }
3976                crate::Statement::Call {
3977                    function,
3978                    ref arguments,
3979                    result,
3980                } => {
3981                    write!(self.out, "{level}")?;
3982                    if let Some(expr) = result {
3983                        let name = Baked(expr).to_string();
3984                        self.start_baking_expression(expr, &context.expression, &name)?;
3985                        self.named_expressions.insert(expr, name);
3986                    }
3987                    let fun_name = &self.names[&NameKey::Function(function)];
3988                    write!(self.out, "{fun_name}(")?;
3989                    // first, write down the actual arguments
3990                    for (i, &handle) in arguments.iter().enumerate() {
3991                        if i != 0 {
3992                            write!(self.out, ", ")?;
3993                        }
3994                        self.put_expression(handle, &context.expression, true)?;
3995                    }
3996                    // follow-up with any global resources used
3997                    let mut separate = !arguments.is_empty();
3998                    let fun_info = &context.expression.mod_info[function];
3999                    let mut needs_buffer_sizes = false;
4000                    for (handle, var) in context.expression.module.global_variables.iter() {
4001                        if fun_info[handle].is_empty() {
4002                            continue;
4003                        }
4004                        if var.space.needs_pass_through() {
4005                            let name = &self.names[&NameKey::GlobalVariable(handle)];
4006                            if separate {
4007                                write!(self.out, ", ")?;
4008                            } else {
4009                                separate = true;
4010                            }
4011                            write!(self.out, "{name}")?;
4012                        }
4013                        needs_buffer_sizes |=
4014                            needs_array_length(var.ty, &context.expression.module.types);
4015                    }
4016                    if needs_buffer_sizes {
4017                        if separate {
4018                            write!(self.out, ", ")?;
4019                        }
4020                        write!(self.out, "_buffer_sizes")?;
4021                    }
4022
4023                    // done
4024                    writeln!(self.out, ");")?;
4025                }
4026                crate::Statement::Atomic {
4027                    pointer,
4028                    ref fun,
4029                    value,
4030                    result,
4031                } => {
4032                    let context = &context.expression;
4033
4034                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
4035                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
4036                    // `Some`, we are not operating on a 64-bit value, and that if we are
4037                    // operating on a 64-bit value, `result` is `None`.
4038                    write!(self.out, "{level}")?;
4039                    let fun_key = if let Some(result) = result {
4040                        let res_name = Baked(result).to_string();
4041                        self.start_baking_expression(result, context, &res_name)?;
4042                        self.named_expressions.insert(result, res_name);
4043                        fun.to_msl()
4044                    } else if context.resolve_type(value).scalar_width() == Some(8) {
4045                        fun.to_msl_64_bit()?
4046                    } else {
4047                        fun.to_msl()
4048                    };
4049
4050                    // If the pointer we're passing to the atomic operation needs to be conditional
4051                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
4052                    // the pointer operand should be unchecked.
4053                    let policy = context.choose_bounds_check_policy(pointer);
4054                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4055                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4056
4057                    // If requested and successfully put bounds checks, continue the ternary expression.
4058                    if checked {
4059                        write!(self.out, " ? ")?;
4060                    }
4061
4062                    // Put the atomic function invocation.
4063                    match *fun {
4064                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4065                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4066                            self.put_access_chain(pointer, policy, context)?;
4067                            write!(self.out, ", ")?;
4068                            self.put_expression(cmp, context, true)?;
4069                            write!(self.out, ", ")?;
4070                            self.put_expression(value, context, true)?;
4071                            write!(self.out, ")")?;
4072                        }
4073                        _ => {
4074                            write!(
4075                                self.out,
4076                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4077                            )?;
4078                            self.put_access_chain(pointer, policy, context)?;
4079                            write!(self.out, ", ")?;
4080                            self.put_expression(value, context, true)?;
4081                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4082                        }
4083                    }
4084
4085                    // Finish the ternary expression.
4086                    if checked {
4087                        write!(self.out, " : DefaultConstructible()")?;
4088                    }
4089
4090                    // Done
4091                    writeln!(self.out, ";")?;
4092                }
4093                crate::Statement::ImageAtomic {
4094                    image,
4095                    coordinate,
4096                    array_index,
4097                    fun,
4098                    value,
4099                } => {
4100                    let address = TexelAddress {
4101                        coordinate,
4102                        array_index,
4103                        sample: None,
4104                        level: None,
4105                    };
4106                    self.put_image_atomic(level, image, &address, fun, value, context)?
4107                }
4108                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4109                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4110
4111                    write!(self.out, "{level}")?;
4112                    let name = self.namer.call("");
4113                    self.start_baking_expression(result, &context.expression, &name)?;
4114                    self.put_load(pointer, &context.expression, true)?;
4115                    self.named_expressions.insert(result, name);
4116
4117                    writeln!(self.out, ";")?;
4118                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4119                }
4120                crate::Statement::RayQuery { query, ref fun } => {
4121                    self.write_ray_query_stmt(level, context, query, fun)?;
4122                }
4123                crate::Statement::SubgroupBallot { result, predicate } => {
4124                    write!(self.out, "{level}")?;
4125                    let name = self.namer.call("");
4126                    self.start_baking_expression(result, &context.expression, &name)?;
4127                    self.named_expressions.insert(result, name);
4128                    write!(
4129                        self.out,
4130                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4131                    )?;
4132                    if let Some(predicate) = predicate {
4133                        self.put_expression(predicate, &context.expression, true)?;
4134                    } else {
4135                        write!(self.out, "true")?;
4136                    }
4137                    writeln!(self.out, "), 0, 0, 0);")?;
4138                }
4139                crate::Statement::SubgroupCollectiveOperation {
4140                    op,
4141                    collective_op,
4142                    argument,
4143                    result,
4144                } => {
4145                    write!(self.out, "{level}")?;
4146                    let name = self.namer.call("");
4147                    self.start_baking_expression(result, &context.expression, &name)?;
4148                    self.named_expressions.insert(result, name);
4149                    match (collective_op, op) {
4150                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4151                            write!(self.out, "{NAMESPACE}::simd_all(")?
4152                        }
4153                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4154                            write!(self.out, "{NAMESPACE}::simd_any(")?
4155                        }
4156                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4157                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4158                        }
4159                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4160                            write!(self.out, "{NAMESPACE}::simd_product(")?
4161                        }
4162                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4163                            write!(self.out, "{NAMESPACE}::simd_max(")?
4164                        }
4165                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4166                            write!(self.out, "{NAMESPACE}::simd_min(")?
4167                        }
4168                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4169                            write!(self.out, "{NAMESPACE}::simd_and(")?
4170                        }
4171                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4172                            write!(self.out, "{NAMESPACE}::simd_or(")?
4173                        }
4174                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4175                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4176                        }
4177                        (
4178                            crate::CollectiveOperation::ExclusiveScan,
4179                            crate::SubgroupOperation::Add,
4180                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4181                        (
4182                            crate::CollectiveOperation::ExclusiveScan,
4183                            crate::SubgroupOperation::Mul,
4184                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4185                        (
4186                            crate::CollectiveOperation::InclusiveScan,
4187                            crate::SubgroupOperation::Add,
4188                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4189                        (
4190                            crate::CollectiveOperation::InclusiveScan,
4191                            crate::SubgroupOperation::Mul,
4192                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4193                        _ => unimplemented!(),
4194                    }
4195                    self.put_expression(argument, &context.expression, true)?;
4196                    writeln!(self.out, ");")?;
4197                }
4198                crate::Statement::SubgroupGather {
4199                    mode,
4200                    argument,
4201                    result,
4202                } => {
4203                    write!(self.out, "{level}")?;
4204                    let name = self.namer.call("");
4205                    self.start_baking_expression(result, &context.expression, &name)?;
4206                    self.named_expressions.insert(result, name);
4207                    match mode {
4208                        crate::GatherMode::BroadcastFirst => {
4209                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4210                        }
4211                        crate::GatherMode::Broadcast(_) => {
4212                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4213                        }
4214                        crate::GatherMode::Shuffle(_) => {
4215                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4216                        }
4217                        crate::GatherMode::ShuffleDown(_) => {
4218                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4219                        }
4220                        crate::GatherMode::ShuffleUp(_) => {
4221                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4222                        }
4223                        crate::GatherMode::ShuffleXor(_) => {
4224                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4225                        }
4226                        crate::GatherMode::QuadBroadcast(_) => {
4227                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4228                        }
4229                        crate::GatherMode::QuadSwap(_) => {
4230                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4231                        }
4232                    }
4233                    self.put_expression(argument, &context.expression, true)?;
4234                    match mode {
4235                        crate::GatherMode::BroadcastFirst => {}
4236                        crate::GatherMode::Broadcast(index)
4237                        | crate::GatherMode::Shuffle(index)
4238                        | crate::GatherMode::ShuffleDown(index)
4239                        | crate::GatherMode::ShuffleUp(index)
4240                        | crate::GatherMode::ShuffleXor(index)
4241                        | crate::GatherMode::QuadBroadcast(index) => {
4242                            write!(self.out, ", ")?;
4243                            self.put_expression(index, &context.expression, true)?;
4244                        }
4245                        crate::GatherMode::QuadSwap(direction) => {
4246                            write!(self.out, ", ")?;
4247                            match direction {
4248                                crate::Direction::X => {
4249                                    write!(self.out, "1u")?;
4250                                }
4251                                crate::Direction::Y => {
4252                                    write!(self.out, "2u")?;
4253                                }
4254                                crate::Direction::Diagonal => {
4255                                    write!(self.out, "3u")?;
4256                                }
4257                            }
4258                        }
4259                    }
4260                    writeln!(self.out, ");")?;
4261                }
4262                crate::Statement::CooperativeStore { target, ref data } => {
4263                    write!(self.out, "{level}simdgroup_store(")?;
4264                    self.put_expression(target, &context.expression, true)?;
4265                    write!(self.out, ", &")?;
4266                    self.put_access_chain(
4267                        data.pointer,
4268                        context.expression.policies.index,
4269                        &context.expression,
4270                    )?;
4271                    write!(self.out, ", ")?;
4272                    self.put_expression(data.stride, &context.expression, true)?;
4273                    if data.row_major {
4274                        let matrix_origin = "0";
4275                        let transpose = true;
4276                        write!(self.out, ", {matrix_origin}, {transpose}")?;
4277                    }
4278                    writeln!(self.out, ");")?;
4279                }
4280                crate::Statement::RayPipelineFunction(_) => unreachable!(),
4281            }
4282        }
4283
4284        // un-emit expressions
4285        //TODO: take care of loop/continuing?
4286        for statement in statements {
4287            if let crate::Statement::Emit(ref range) = *statement {
4288                for handle in range.clone() {
4289                    self.named_expressions.shift_remove(&handle);
4290                }
4291            }
4292        }
4293        Ok(())
4294    }
4295
4296    fn put_store(
4297        &mut self,
4298        pointer: Handle<crate::Expression>,
4299        value: Handle<crate::Expression>,
4300        level: back::Level,
4301        context: &StatementContext,
4302    ) -> BackendResult {
4303        let policy = context.expression.choose_bounds_check_policy(pointer);
4304        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4305            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4306        {
4307            writeln!(self.out, ") {{")?;
4308            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4309            writeln!(self.out, "{level}}}")?;
4310        } else {
4311            self.put_unchecked_store(pointer, value, policy, level, context)?;
4312        }
4313
4314        Ok(())
4315    }
4316
4317    fn put_unchecked_store(
4318        &mut self,
4319        pointer: Handle<crate::Expression>,
4320        value: Handle<crate::Expression>,
4321        policy: index::BoundsCheckPolicy,
4322        level: back::Level,
4323        context: &StatementContext,
4324    ) -> BackendResult {
4325        let is_atomic_pointer = context
4326            .expression
4327            .resolve_type(pointer)
4328            .is_atomic_pointer(&context.expression.module.types);
4329
4330        if is_atomic_pointer {
4331            write!(
4332                self.out,
4333                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4334            )?;
4335            self.put_access_chain(pointer, policy, &context.expression)?;
4336            write!(self.out, ", ")?;
4337            self.put_expression(value, &context.expression, true)?;
4338            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4339        } else {
4340            write!(self.out, "{level}")?;
4341            self.put_access_chain(pointer, policy, &context.expression)?;
4342            write!(self.out, " = ")?;
4343            self.put_expression(value, &context.expression, true)?;
4344            writeln!(self.out, ";")?;
4345        }
4346
4347        Ok(())
4348    }
4349
4350    pub fn write(
4351        &mut self,
4352        module: &crate::Module,
4353        info: &valid::ModuleInfo,
4354        options: &Options,
4355        pipeline_options: &PipelineOptions,
4356    ) -> Result<TranslationInfo, Error> {
4357        self.names.clear();
4358        self.namer.reset(
4359            module,
4360            &super::keywords::RESERVED_SET,
4361            proc::KeywordSet::empty(),
4362            proc::CaseInsensitiveKeywordSet::empty(),
4363            &[
4364                CLAMPED_LOD_LOAD_PREFIX,
4365                super::ray::INTERSECTION_FUNCTION_NAME,
4366            ],
4367            &mut self.names,
4368        );
4369        self.wrapped_functions.clear();
4370        self.struct_member_pads.clear();
4371
4372        writeln!(
4373            self.out,
4374            "// language: metal{}.{}",
4375            options.lang_version.0, options.lang_version.1
4376        )?;
4377        writeln!(self.out, "#include <metal_stdlib>")?;
4378        writeln!(self.out, "#include <simd/simd.h>")?;
4379        writeln!(self.out)?;
4380        // Work around Metal bug where `uint` is not available by default
4381        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4382
4383        if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4384            return Err(Error::UnsupportedMeshShader);
4385        }
4386        self.needs_object_memory_barriers = module
4387            .entry_points
4388            .iter()
4389            .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4390
4391        if module.special_types.ray_desc.is_some()
4392            || module.special_types.ray_intersection.is_some()
4393        {
4394            if options.lang_version < (2, 4) {
4395                return Err(Error::UnsupportedRayTracing);
4396            }
4397        }
4398
4399        if options
4400            .bounds_check_policies
4401            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4402        {
4403            self.put_default_constructible()?;
4404        }
4405        writeln!(self.out)?;
4406
4407        {
4408            // Make a `Vec` of all the `GlobalVariable`s that contain
4409            // runtime-sized arrays.
4410            let globals: Vec<Handle<crate::GlobalVariable>> = module
4411                .global_variables
4412                .iter()
4413                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4414                .map(|(handle, _)| handle)
4415                .collect();
4416
4417            let mut buffer_indices = vec![];
4418            for vbm in &pipeline_options.vertex_buffer_mappings {
4419                buffer_indices.push(vbm.id);
4420            }
4421
4422            if !globals.is_empty() || !buffer_indices.is_empty() {
4423                writeln!(self.out, "struct _mslBufferSizes {{")?;
4424
4425                for global in globals {
4426                    writeln!(
4427                        self.out,
4428                        "{}uint {};",
4429                        back::INDENT,
4430                        ArraySizeMember(global)
4431                    )?;
4432                }
4433
4434                for idx in buffer_indices {
4435                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4436                }
4437
4438                writeln!(self.out, "}};")?;
4439                writeln!(self.out)?;
4440            }
4441        };
4442
4443        self.write_type_defs(module)?;
4444        self.write_global_constants(module, info)?;
4445        self.write_functions(module, info, options, pipeline_options)
4446    }
4447
4448    /// Write the definition for the `DefaultConstructible` class.
4449    ///
4450    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4451    /// produce 'zero' values for any type, including structs, arrays, and so
4452    /// on. We could do this by emitting default constructor applications, but
4453    /// that would entail printing the name of the type, which is more trouble
4454    /// than you'd think. Instead, we just construct this magic C++14 class that
4455    /// can be converted to any type that can be default constructed, using
4456    /// template parameter inference to detect which type is needed, so we don't
4457    /// have to figure out the name.
4458    ///
4459    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4460    fn put_default_constructible(&mut self) -> BackendResult {
4461        let tab = back::INDENT;
4462        writeln!(self.out, "struct DefaultConstructible {{")?;
4463        writeln!(self.out, "{tab}template<typename T>")?;
4464        writeln!(self.out, "{tab}operator T() && {{")?;
4465        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4466        writeln!(self.out, "{tab}}}")?;
4467        writeln!(self.out, "}};")?;
4468        Ok(())
4469    }
4470
4471    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4472        let mut generated_argument_buffer_wrapper = false;
4473        let mut generated_external_texture_wrapper = false;
4474        for (handle, ty) in module.types.iter() {
4475            match ty.inner {
4476                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4477                    writeln!(self.out, "template <typename T>")?;
4478                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4479                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4480                    writeln!(self.out, "}};")?;
4481                    generated_argument_buffer_wrapper = true;
4482                }
4483                crate::TypeInner::Image {
4484                    class: crate::ImageClass::External,
4485                    ..
4486                } if !generated_external_texture_wrapper => {
4487                    let params_ty_name = &self.names
4488                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4489                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4490                    writeln!(
4491                        self.out,
4492                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4493                        back::INDENT
4494                    )?;
4495                    writeln!(
4496                        self.out,
4497                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4498                        back::INDENT
4499                    )?;
4500                    writeln!(
4501                        self.out,
4502                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4503                        back::INDENT
4504                    )?;
4505                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4506                    writeln!(self.out, "}};")?;
4507                    generated_external_texture_wrapper = true;
4508                }
4509                _ => {}
4510            }
4511
4512            if !ty.needs_alias() {
4513                continue;
4514            }
4515            let name = &self.names[&NameKey::Type(handle)];
4516            match ty.inner {
4517                // Naga IR can pass around arrays by value, but Metal, following
4518                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4519                // on expressions of array type, so assigning the array by value
4520                // isn't possible. However, Metal *does* assign structs by
4521                // value. So in our Metal output, we wrap all array types in
4522                // synthetic struct types:
4523                //
4524                //     struct type1 {
4525                //         float inner[10]
4526                //     };
4527                //
4528                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4529                // any expression that actually wants access to the array.
4530                crate::TypeInner::Array {
4531                    base,
4532                    size,
4533                    stride: _,
4534                } => {
4535                    let base_name = TypeContext {
4536                        handle: base,
4537                        gctx: module.to_ctx(),
4538                        names: &self.names,
4539                        access: crate::StorageAccess::empty(),
4540                        first_time: false,
4541                    };
4542
4543                    match size.resolve(module.to_ctx())? {
4544                        proc::IndexableLength::Known(size) => {
4545                            writeln!(self.out, "struct {name} {{")?;
4546                            writeln!(
4547                                self.out,
4548                                "{}{} {}[{}];",
4549                                back::INDENT,
4550                                base_name,
4551                                WRAPPED_ARRAY_FIELD,
4552                                size
4553                            )?;
4554                            writeln!(self.out, "}};")?;
4555                        }
4556                        proc::IndexableLength::Dynamic => {
4557                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4558                        }
4559                    }
4560                }
4561                crate::TypeInner::Struct {
4562                    ref members, span, ..
4563                } => {
4564                    writeln!(self.out, "struct {name} {{")?;
4565                    let mut last_offset = 0;
4566                    for (index, member) in members.iter().enumerate() {
4567                        if member.offset > last_offset {
4568                            self.struct_member_pads.insert((handle, index as u32));
4569                            let pad = member.offset - last_offset;
4570                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4571                        }
4572                        let ty_inner = &module.types[member.ty].inner;
4573                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4574
4575                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4576
4577                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4578                        match should_pack_struct_member(members, span, index, module) {
4579                            Some(scalar) => {
4580                                writeln!(
4581                                    self.out,
4582                                    "{}{}::packed_{}3 {};",
4583                                    back::INDENT,
4584                                    NAMESPACE,
4585                                    scalar.to_msl_name(),
4586                                    member_name
4587                                )?;
4588                            }
4589                            None => {
4590                                let base_name = TypeContext {
4591                                    handle: member.ty,
4592                                    gctx: module.to_ctx(),
4593                                    names: &self.names,
4594                                    access: crate::StorageAccess::empty(),
4595                                    first_time: false,
4596                                };
4597                                writeln!(
4598                                    self.out,
4599                                    "{}{} {};",
4600                                    back::INDENT,
4601                                    base_name,
4602                                    member_name
4603                                )?;
4604
4605                                // for 3-component vectors, add one component
4606                                if let crate::TypeInner::Vector {
4607                                    size: crate::VectorSize::Tri,
4608                                    scalar,
4609                                } = *ty_inner
4610                                {
4611                                    last_offset += scalar.width as u32;
4612                                }
4613                            }
4614                        }
4615                    }
4616                    if last_offset < span {
4617                        let pad = span - last_offset;
4618                        writeln!(
4619                            self.out,
4620                            "{}char _pad{}[{}];",
4621                            back::INDENT,
4622                            members.len(),
4623                            pad
4624                        )?;
4625                    }
4626                    writeln!(self.out, "}};")?;
4627                }
4628                _ => {
4629                    let ty_name = TypeContext {
4630                        handle,
4631                        gctx: module.to_ctx(),
4632                        names: &self.names,
4633                        access: crate::StorageAccess::empty(),
4634                        first_time: true,
4635                    };
4636                    writeln!(self.out, "typedef {ty_name} {name};")?;
4637                }
4638            }
4639        }
4640
4641        // Write functions to create special types.
4642        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4643            match type_key {
4644                &crate::PredeclaredType::ModfResult { size, scalar }
4645                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4646                    let arg_type_name_owner;
4647                    let arg_type_name = if let Some(size) = size {
4648                        arg_type_name_owner = format!(
4649                            "{NAMESPACE}::{}{}",
4650                            if scalar.width == 8 { "double" } else { "float" },
4651                            size as u8
4652                        );
4653                        &arg_type_name_owner
4654                    } else if scalar.width == 8 {
4655                        "double"
4656                    } else {
4657                        "float"
4658                    };
4659
4660                    let other_type_name_owner;
4661                    let (defined_func_name, called_func_name, other_type_name) =
4662                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4663                            (MODF_FUNCTION, "modf", arg_type_name)
4664                        } else {
4665                            let other_type_name = if let Some(size) = size {
4666                                other_type_name_owner = format!("int{}", size as u8);
4667                                &other_type_name_owner
4668                            } else {
4669                                "int"
4670                            };
4671                            (FREXP_FUNCTION, "frexp", other_type_name)
4672                        };
4673
4674                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4675
4676                    writeln!(self.out)?;
4677                    writeln!(
4678                        self.out,
4679                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4680    {other_type_name} other;
4681    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4682    return {struct_name}{{ fract, other }};
4683}}"
4684                    )?;
4685                }
4686                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4687                    let arg_type_name = scalar.to_msl_name();
4688                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4689                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4690                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4691
4692                    writeln!(self.out)?;
4693
4694                    for address_space_name in ["device", "threadgroup"] {
4695                        writeln!(
4696                            self.out,
4697                            "\
4698template <typename A>
4699{struct_name} {defined_func_name}(
4700    {address_space_name} A *atomic_ptr,
4701    {arg_type_name} cmp,
4702    {arg_type_name} v
4703) {{
4704    bool swapped = {NAMESPACE}::{called_func_name}(
4705        atomic_ptr, &cmp, v,
4706        metal::memory_order_relaxed, metal::memory_order_relaxed
4707    );
4708    return {struct_name}{{cmp, swapped}};
4709}}"
4710                        )?;
4711                    }
4712                }
4713            }
4714        }
4715
4716        Ok(())
4717    }
4718
4719    /// Writes all named constants
4720    fn write_global_constants(
4721        &mut self,
4722        module: &crate::Module,
4723        mod_info: &valid::ModuleInfo,
4724    ) -> BackendResult {
4725        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4726
4727        for (handle, constant) in constants {
4728            let ty_name = TypeContext {
4729                handle: constant.ty,
4730                gctx: module.to_ctx(),
4731                names: &self.names,
4732                access: crate::StorageAccess::empty(),
4733                first_time: false,
4734            };
4735            let name = &self.names[&NameKey::Constant(handle)];
4736            write!(self.out, "constant {ty_name} {name} = ")?;
4737            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4738            writeln!(self.out, ";")?;
4739        }
4740
4741        Ok(())
4742    }
4743
4744    fn put_inline_sampler_properties(
4745        &mut self,
4746        level: back::Level,
4747        sampler: &sm::InlineSampler,
4748    ) -> BackendResult {
4749        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4750            writeln!(
4751                self.out,
4752                "{}{}::{}_address::{},",
4753                level,
4754                NAMESPACE,
4755                letter,
4756                address.as_str(),
4757            )?;
4758        }
4759        writeln!(
4760            self.out,
4761            "{}{}::mag_filter::{},",
4762            level,
4763            NAMESPACE,
4764            sampler.mag_filter.as_str(),
4765        )?;
4766        writeln!(
4767            self.out,
4768            "{}{}::min_filter::{},",
4769            level,
4770            NAMESPACE,
4771            sampler.min_filter.as_str(),
4772        )?;
4773        if let Some(filter) = sampler.mip_filter {
4774            writeln!(
4775                self.out,
4776                "{}{}::mip_filter::{},",
4777                level,
4778                NAMESPACE,
4779                filter.as_str(),
4780            )?;
4781        }
4782        // avoid setting it on platforms that don't support it
4783        if sampler.border_color != sm::BorderColor::TransparentBlack {
4784            writeln!(
4785                self.out,
4786                "{}{}::border_color::{},",
4787                level,
4788                NAMESPACE,
4789                sampler.border_color.as_str(),
4790            )?;
4791        }
4792        //TODO: I'm not able to feed this in a way that MSL likes:
4793        //>error: use of undeclared identifier 'lod_clamp'
4794        //>error: no member named 'max_anisotropy' in namespace 'metal'
4795        if false {
4796            if let Some(ref lod) = sampler.lod_clamp {
4797                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4798            }
4799            if let Some(aniso) = sampler.max_anisotropy {
4800                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4801            }
4802        }
4803        if sampler.compare_func != sm::CompareFunc::Never {
4804            writeln!(
4805                self.out,
4806                "{}{}::compare_func::{},",
4807                level,
4808                NAMESPACE,
4809                sampler.compare_func.as_str(),
4810            )?;
4811        }
4812        writeln!(
4813            self.out,
4814            "{}{}::coord::{}",
4815            level,
4816            NAMESPACE,
4817            sampler.coord.as_str()
4818        )?;
4819        Ok(())
4820    }
4821
4822    fn write_unpacking_function(
4823        &mut self,
4824        format: back::msl::VertexFormat,
4825    ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4826        use crate::{Scalar, VectorSize};
4827        use back::msl::VertexFormat::*;
4828        match format {
4829            Uint8 => {
4830                let name = self.namer.call("unpackUint8");
4831                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4832                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4833                writeln!(self.out, "}}")?;
4834                Ok((name, 1, None, Scalar::U32))
4835            }
4836            Uint8x2 => {
4837                let name = self.namer.call("unpackUint8x2");
4838                writeln!(
4839                    self.out,
4840                    "metal::uint2 {name}(metal::uchar b0, \
4841                                         metal::uchar b1) {{"
4842                )?;
4843                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4844                writeln!(self.out, "}}")?;
4845                Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4846            }
4847            Uint8x4 => {
4848                let name = self.namer.call("unpackUint8x4");
4849                writeln!(
4850                    self.out,
4851                    "metal::uint4 {name}(metal::uchar b0, \
4852                                         metal::uchar b1, \
4853                                         metal::uchar b2, \
4854                                         metal::uchar b3) {{"
4855                )?;
4856                writeln!(
4857                    self.out,
4858                    "{}return metal::uint4(b0, b1, b2, b3);",
4859                    back::INDENT
4860                )?;
4861                writeln!(self.out, "}}")?;
4862                Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
4863            }
4864            Sint8 => {
4865                let name = self.namer.call("unpackSint8");
4866                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4867                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4868                writeln!(self.out, "}}")?;
4869                Ok((name, 1, None, Scalar::I32))
4870            }
4871            Sint8x2 => {
4872                let name = self.namer.call("unpackSint8x2");
4873                writeln!(
4874                    self.out,
4875                    "metal::int2 {name}(metal::uchar b0, \
4876                                        metal::uchar b1) {{"
4877                )?;
4878                writeln!(
4879                    self.out,
4880                    "{}return metal::int2(as_type<char>(b0), \
4881                                          as_type<char>(b1));",
4882                    back::INDENT
4883                )?;
4884                writeln!(self.out, "}}")?;
4885                Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
4886            }
4887            Sint8x4 => {
4888                let name = self.namer.call("unpackSint8x4");
4889                writeln!(
4890                    self.out,
4891                    "metal::int4 {name}(metal::uchar b0, \
4892                                        metal::uchar b1, \
4893                                        metal::uchar b2, \
4894                                        metal::uchar b3) {{"
4895                )?;
4896                writeln!(
4897                    self.out,
4898                    "{}return metal::int4(as_type<char>(b0), \
4899                                          as_type<char>(b1), \
4900                                          as_type<char>(b2), \
4901                                          as_type<char>(b3));",
4902                    back::INDENT
4903                )?;
4904                writeln!(self.out, "}}")?;
4905                Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
4906            }
4907            Unorm8 => {
4908                let name = self.namer.call("unpackUnorm8");
4909                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4910                writeln!(
4911                    self.out,
4912                    "{}return float(float(b0) / 255.0f);",
4913                    back::INDENT
4914                )?;
4915                writeln!(self.out, "}}")?;
4916                Ok((name, 1, None, Scalar::F32))
4917            }
4918            Unorm8x2 => {
4919                let name = self.namer.call("unpackUnorm8x2");
4920                writeln!(
4921                    self.out,
4922                    "metal::float2 {name}(metal::uchar b0, \
4923                                          metal::uchar b1) {{"
4924                )?;
4925                writeln!(
4926                    self.out,
4927                    "{}return metal::float2(float(b0) / 255.0f, \
4928                                            float(b1) / 255.0f);",
4929                    back::INDENT
4930                )?;
4931                writeln!(self.out, "}}")?;
4932                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4933            }
4934            Unorm8x4 => {
4935                let name = self.namer.call("unpackUnorm8x4");
4936                writeln!(
4937                    self.out,
4938                    "metal::float4 {name}(metal::uchar b0, \
4939                                          metal::uchar b1, \
4940                                          metal::uchar b2, \
4941                                          metal::uchar b3) {{"
4942                )?;
4943                writeln!(
4944                    self.out,
4945                    "{}return metal::float4(float(b0) / 255.0f, \
4946                                            float(b1) / 255.0f, \
4947                                            float(b2) / 255.0f, \
4948                                            float(b3) / 255.0f);",
4949                    back::INDENT
4950                )?;
4951                writeln!(self.out, "}}")?;
4952                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
4953            }
4954            Snorm8 => {
4955                let name = self.namer.call("unpackSnorm8");
4956                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4957                writeln!(
4958                    self.out,
4959                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4960                    back::INDENT
4961                )?;
4962                writeln!(self.out, "}}")?;
4963                Ok((name, 1, None, Scalar::F32))
4964            }
4965            Snorm8x2 => {
4966                let name = self.namer.call("unpackSnorm8x2");
4967                writeln!(
4968                    self.out,
4969                    "metal::float2 {name}(metal::uchar b0, \
4970                                          metal::uchar b1) {{"
4971                )?;
4972                writeln!(
4973                    self.out,
4974                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4975                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4976                    back::INDENT
4977                )?;
4978                writeln!(self.out, "}}")?;
4979                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4980            }
4981            Snorm8x4 => {
4982                let name = self.namer.call("unpackSnorm8x4");
4983                writeln!(
4984                    self.out,
4985                    "metal::float4 {name}(metal::uchar b0, \
4986                                          metal::uchar b1, \
4987                                          metal::uchar b2, \
4988                                          metal::uchar b3) {{"
4989                )?;
4990                writeln!(
4991                    self.out,
4992                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4993                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4994                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4995                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4996                    back::INDENT
4997                )?;
4998                writeln!(self.out, "}}")?;
4999                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5000            }
5001            Uint16 => {
5002                let name = self.namer.call("unpackUint16");
5003                writeln!(
5004                    self.out,
5005                    "metal::uint {name}(metal::uint b0, \
5006                                        metal::uint b1) {{"
5007                )?;
5008                writeln!(
5009                    self.out,
5010                    "{}return metal::uint(b1 << 8 | b0);",
5011                    back::INDENT
5012                )?;
5013                writeln!(self.out, "}}")?;
5014                Ok((name, 2, None, Scalar::U32))
5015            }
5016            Uint16x2 => {
5017                let name = self.namer.call("unpackUint16x2");
5018                writeln!(
5019                    self.out,
5020                    "metal::uint2 {name}(metal::uint b0, \
5021                                         metal::uint b1, \
5022                                         metal::uint b2, \
5023                                         metal::uint b3) {{"
5024                )?;
5025                writeln!(
5026                    self.out,
5027                    "{}return metal::uint2(b1 << 8 | b0, \
5028                                           b3 << 8 | b2);",
5029                    back::INDENT
5030                )?;
5031                writeln!(self.out, "}}")?;
5032                Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5033            }
5034            Uint16x4 => {
5035                let name = self.namer.call("unpackUint16x4");
5036                writeln!(
5037                    self.out,
5038                    "metal::uint4 {name}(metal::uint b0, \
5039                                         metal::uint b1, \
5040                                         metal::uint b2, \
5041                                         metal::uint b3, \
5042                                         metal::uint b4, \
5043                                         metal::uint b5, \
5044                                         metal::uint b6, \
5045                                         metal::uint b7) {{"
5046                )?;
5047                writeln!(
5048                    self.out,
5049                    "{}return metal::uint4(b1 << 8 | b0, \
5050                                           b3 << 8 | b2, \
5051                                           b5 << 8 | b4, \
5052                                           b7 << 8 | b6);",
5053                    back::INDENT
5054                )?;
5055                writeln!(self.out, "}}")?;
5056                Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5057            }
5058            Sint16 => {
5059                let name = self.namer.call("unpackSint16");
5060                writeln!(
5061                    self.out,
5062                    "int {name}(metal::ushort b0, \
5063                                metal::ushort b1) {{"
5064                )?;
5065                writeln!(
5066                    self.out,
5067                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5068                    back::INDENT
5069                )?;
5070                writeln!(self.out, "}}")?;
5071                Ok((name, 2, None, Scalar::I32))
5072            }
5073            Sint16x2 => {
5074                let name = self.namer.call("unpackSint16x2");
5075                writeln!(
5076                    self.out,
5077                    "metal::int2 {name}(metal::ushort b0, \
5078                                        metal::ushort b1, \
5079                                        metal::ushort b2, \
5080                                        metal::ushort b3) {{"
5081                )?;
5082                writeln!(
5083                    self.out,
5084                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5085                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5086                    back::INDENT
5087                )?;
5088                writeln!(self.out, "}}")?;
5089                Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5090            }
5091            Sint16x4 => {
5092                let name = self.namer.call("unpackSint16x4");
5093                writeln!(
5094                    self.out,
5095                    "metal::int4 {name}(metal::ushort b0, \
5096                                        metal::ushort b1, \
5097                                        metal::ushort b2, \
5098                                        metal::ushort b3, \
5099                                        metal::ushort b4, \
5100                                        metal::ushort b5, \
5101                                        metal::ushort b6, \
5102                                        metal::ushort b7) {{"
5103                )?;
5104                writeln!(
5105                    self.out,
5106                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5107                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5108                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5109                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5110                    back::INDENT
5111                )?;
5112                writeln!(self.out, "}}")?;
5113                Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5114            }
5115            Unorm16 => {
5116                let name = self.namer.call("unpackUnorm16");
5117                writeln!(
5118                    self.out,
5119                    "float {name}(metal::ushort b0, \
5120                                  metal::ushort b1) {{"
5121                )?;
5122                writeln!(
5123                    self.out,
5124                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5125                    back::INDENT
5126                )?;
5127                writeln!(self.out, "}}")?;
5128                Ok((name, 2, None, Scalar::F32))
5129            }
5130            Unorm16x2 => {
5131                let name = self.namer.call("unpackUnorm16x2");
5132                writeln!(
5133                    self.out,
5134                    "metal::float2 {name}(metal::ushort b0, \
5135                                          metal::ushort b1, \
5136                                          metal::ushort b2, \
5137                                          metal::ushort b3) {{"
5138                )?;
5139                writeln!(
5140                    self.out,
5141                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5142                                            float(b3 << 8 | b2) / 65535.0f);",
5143                    back::INDENT
5144                )?;
5145                writeln!(self.out, "}}")?;
5146                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5147            }
5148            Unorm16x4 => {
5149                let name = self.namer.call("unpackUnorm16x4");
5150                writeln!(
5151                    self.out,
5152                    "metal::float4 {name}(metal::ushort b0, \
5153                                          metal::ushort b1, \
5154                                          metal::ushort b2, \
5155                                          metal::ushort b3, \
5156                                          metal::ushort b4, \
5157                                          metal::ushort b5, \
5158                                          metal::ushort b6, \
5159                                          metal::ushort b7) {{"
5160                )?;
5161                writeln!(
5162                    self.out,
5163                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5164                                            float(b3 << 8 | b2) / 65535.0f, \
5165                                            float(b5 << 8 | b4) / 65535.0f, \
5166                                            float(b7 << 8 | b6) / 65535.0f);",
5167                    back::INDENT
5168                )?;
5169                writeln!(self.out, "}}")?;
5170                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5171            }
5172            Snorm16 => {
5173                let name = self.namer.call("unpackSnorm16");
5174                writeln!(
5175                    self.out,
5176                    "float {name}(metal::ushort b0, \
5177                                  metal::ushort b1) {{"
5178                )?;
5179                writeln!(
5180                    self.out,
5181                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5182                    back::INDENT
5183                )?;
5184                writeln!(self.out, "}}")?;
5185                Ok((name, 2, None, Scalar::F32))
5186            }
5187            Snorm16x2 => {
5188                let name = self.namer.call("unpackSnorm16x2");
5189                writeln!(
5190                    self.out,
5191                    "metal::float2 {name}(uint b0, \
5192                                          uint b1, \
5193                                          uint b2, \
5194                                          uint b3) {{"
5195                )?;
5196                writeln!(
5197                    self.out,
5198                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5199                    back::INDENT
5200                )?;
5201                writeln!(self.out, "}}")?;
5202                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5203            }
5204            Snorm16x4 => {
5205                let name = self.namer.call("unpackSnorm16x4");
5206                writeln!(
5207                    self.out,
5208                    "metal::float4 {name}(uint b0, \
5209                                          uint b1, \
5210                                          uint b2, \
5211                                          uint b3, \
5212                                          uint b4, \
5213                                          uint b5, \
5214                                          uint b6, \
5215                                          uint b7) {{"
5216                )?;
5217                writeln!(
5218                    self.out,
5219                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5220                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5221                    back::INDENT
5222                )?;
5223                writeln!(self.out, "}}")?;
5224                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5225            }
5226            Float16 => {
5227                let name = self.namer.call("unpackFloat16");
5228                writeln!(
5229                    self.out,
5230                    "float {name}(metal::ushort b0, \
5231                                  metal::ushort b1) {{"
5232                )?;
5233                writeln!(
5234                    self.out,
5235                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5236                    back::INDENT
5237                )?;
5238                writeln!(self.out, "}}")?;
5239                Ok((name, 2, None, Scalar::F32))
5240            }
5241            Float16x2 => {
5242                let name = self.namer.call("unpackFloat16x2");
5243                writeln!(
5244                    self.out,
5245                    "metal::float2 {name}(metal::ushort b0, \
5246                                          metal::ushort b1, \
5247                                          metal::ushort b2, \
5248                                          metal::ushort b3) {{"
5249                )?;
5250                writeln!(
5251                    self.out,
5252                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5253                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5254                    back::INDENT
5255                )?;
5256                writeln!(self.out, "}}")?;
5257                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5258            }
5259            Float16x4 => {
5260                let name = self.namer.call("unpackFloat16x4");
5261                writeln!(
5262                    self.out,
5263                    "metal::float4 {name}(metal::ushort b0, \
5264                                        metal::ushort b1, \
5265                                        metal::ushort b2, \
5266                                        metal::ushort b3, \
5267                                        metal::ushort b4, \
5268                                        metal::ushort b5, \
5269                                        metal::ushort b6, \
5270                                        metal::ushort b7) {{"
5271                )?;
5272                writeln!(
5273                    self.out,
5274                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5275                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5276                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5277                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5278                    back::INDENT
5279                )?;
5280                writeln!(self.out, "}}")?;
5281                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5282            }
5283            Float32 => {
5284                let name = self.namer.call("unpackFloat32");
5285                writeln!(
5286                    self.out,
5287                    "float {name}(uint b0, \
5288                                  uint b1, \
5289                                  uint b2, \
5290                                  uint b3) {{"
5291                )?;
5292                writeln!(
5293                    self.out,
5294                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5295                    back::INDENT
5296                )?;
5297                writeln!(self.out, "}}")?;
5298                Ok((name, 4, None, Scalar::F32))
5299            }
5300            Float32x2 => {
5301                let name = self.namer.call("unpackFloat32x2");
5302                writeln!(
5303                    self.out,
5304                    "metal::float2 {name}(uint b0, \
5305                                          uint b1, \
5306                                          uint b2, \
5307                                          uint b3, \
5308                                          uint b4, \
5309                                          uint b5, \
5310                                          uint b6, \
5311                                          uint b7) {{"
5312                )?;
5313                writeln!(
5314                    self.out,
5315                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5316                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5317                    back::INDENT
5318                )?;
5319                writeln!(self.out, "}}")?;
5320                Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5321            }
5322            Float32x3 => {
5323                let name = self.namer.call("unpackFloat32x3");
5324                writeln!(
5325                    self.out,
5326                    "metal::float3 {name}(uint b0, \
5327                                          uint b1, \
5328                                          uint b2, \
5329                                          uint b3, \
5330                                          uint b4, \
5331                                          uint b5, \
5332                                          uint b6, \
5333                                          uint b7, \
5334                                          uint b8, \
5335                                          uint b9, \
5336                                          uint b10, \
5337                                          uint b11) {{"
5338                )?;
5339                writeln!(
5340                    self.out,
5341                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5342                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5343                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5344                    back::INDENT
5345                )?;
5346                writeln!(self.out, "}}")?;
5347                Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5348            }
5349            Float32x4 => {
5350                let name = self.namer.call("unpackFloat32x4");
5351                writeln!(
5352                    self.out,
5353                    "metal::float4 {name}(uint b0, \
5354                                          uint b1, \
5355                                          uint b2, \
5356                                          uint b3, \
5357                                          uint b4, \
5358                                          uint b5, \
5359                                          uint b6, \
5360                                          uint b7, \
5361                                          uint b8, \
5362                                          uint b9, \
5363                                          uint b10, \
5364                                          uint b11, \
5365                                          uint b12, \
5366                                          uint b13, \
5367                                          uint b14, \
5368                                          uint b15) {{"
5369                )?;
5370                writeln!(
5371                    self.out,
5372                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5373                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5374                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5375                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5376                    back::INDENT
5377                )?;
5378                writeln!(self.out, "}}")?;
5379                Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5380            }
5381            Uint32 => {
5382                let name = self.namer.call("unpackUint32");
5383                writeln!(
5384                    self.out,
5385                    "uint {name}(uint b0, \
5386                                 uint b1, \
5387                                 uint b2, \
5388                                 uint b3) {{"
5389                )?;
5390                writeln!(
5391                    self.out,
5392                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5393                    back::INDENT
5394                )?;
5395                writeln!(self.out, "}}")?;
5396                Ok((name, 4, None, Scalar::U32))
5397            }
5398            Uint32x2 => {
5399                let name = self.namer.call("unpackUint32x2");
5400                writeln!(
5401                    self.out,
5402                    "uint2 {name}(uint b0, \
5403                                  uint b1, \
5404                                  uint b2, \
5405                                  uint b3, \
5406                                  uint b4, \
5407                                  uint b5, \
5408                                  uint b6, \
5409                                  uint b7) {{"
5410                )?;
5411                writeln!(
5412                    self.out,
5413                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5414                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5415                    back::INDENT
5416                )?;
5417                writeln!(self.out, "}}")?;
5418                Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5419            }
5420            Uint32x3 => {
5421                let name = self.namer.call("unpackUint32x3");
5422                writeln!(
5423                    self.out,
5424                    "uint3 {name}(uint b0, \
5425                                  uint b1, \
5426                                  uint b2, \
5427                                  uint b3, \
5428                                  uint b4, \
5429                                  uint b5, \
5430                                  uint b6, \
5431                                  uint b7, \
5432                                  uint b8, \
5433                                  uint b9, \
5434                                  uint b10, \
5435                                  uint b11) {{"
5436                )?;
5437                writeln!(
5438                    self.out,
5439                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5440                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5441                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5442                    back::INDENT
5443                )?;
5444                writeln!(self.out, "}}")?;
5445                Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5446            }
5447            Uint32x4 => {
5448                let name = self.namer.call("unpackUint32x4");
5449                writeln!(
5450                    self.out,
5451                    "{NAMESPACE}::uint4 {name}(uint b0, \
5452                                  uint b1, \
5453                                  uint b2, \
5454                                  uint b3, \
5455                                  uint b4, \
5456                                  uint b5, \
5457                                  uint b6, \
5458                                  uint b7, \
5459                                  uint b8, \
5460                                  uint b9, \
5461                                  uint b10, \
5462                                  uint b11, \
5463                                  uint b12, \
5464                                  uint b13, \
5465                                  uint b14, \
5466                                  uint b15) {{"
5467                )?;
5468                writeln!(
5469                    self.out,
5470                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5471                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5472                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5473                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5474                    back::INDENT
5475                )?;
5476                writeln!(self.out, "}}")?;
5477                Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5478            }
5479            Sint32 => {
5480                let name = self.namer.call("unpackSint32");
5481                writeln!(
5482                    self.out,
5483                    "int {name}(uint b0, \
5484                                uint b1, \
5485                                uint b2, \
5486                                uint b3) {{"
5487                )?;
5488                writeln!(
5489                    self.out,
5490                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5491                    back::INDENT
5492                )?;
5493                writeln!(self.out, "}}")?;
5494                Ok((name, 4, None, Scalar::I32))
5495            }
5496            Sint32x2 => {
5497                let name = self.namer.call("unpackSint32x2");
5498                writeln!(
5499                    self.out,
5500                    "metal::int2 {name}(uint b0, \
5501                                        uint b1, \
5502                                        uint b2, \
5503                                        uint b3, \
5504                                        uint b4, \
5505                                        uint b5, \
5506                                        uint b6, \
5507                                        uint b7) {{"
5508                )?;
5509                writeln!(
5510                    self.out,
5511                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5512                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5513                    back::INDENT
5514                )?;
5515                writeln!(self.out, "}}")?;
5516                Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5517            }
5518            Sint32x3 => {
5519                let name = self.namer.call("unpackSint32x3");
5520                writeln!(
5521                    self.out,
5522                    "metal::int3 {name}(uint b0, \
5523                                        uint b1, \
5524                                        uint b2, \
5525                                        uint b3, \
5526                                        uint b4, \
5527                                        uint b5, \
5528                                        uint b6, \
5529                                        uint b7, \
5530                                        uint b8, \
5531                                        uint b9, \
5532                                        uint b10, \
5533                                        uint b11) {{"
5534                )?;
5535                writeln!(
5536                    self.out,
5537                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5538                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5539                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5540                    back::INDENT
5541                )?;
5542                writeln!(self.out, "}}")?;
5543                Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5544            }
5545            Sint32x4 => {
5546                let name = self.namer.call("unpackSint32x4");
5547                writeln!(
5548                    self.out,
5549                    "metal::int4 {name}(uint b0, \
5550                                        uint b1, \
5551                                        uint b2, \
5552                                        uint b3, \
5553                                        uint b4, \
5554                                        uint b5, \
5555                                        uint b6, \
5556                                        uint b7, \
5557                                        uint b8, \
5558                                        uint b9, \
5559                                        uint b10, \
5560                                        uint b11, \
5561                                        uint b12, \
5562                                        uint b13, \
5563                                        uint b14, \
5564                                        uint b15) {{"
5565                )?;
5566                writeln!(
5567                    self.out,
5568                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5569                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5570                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5571                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5572                    back::INDENT
5573                )?;
5574                writeln!(self.out, "}}")?;
5575                Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5576            }
5577            Unorm10_10_10_2 => {
5578                let name = self.namer.call("unpackUnorm10_10_10_2");
5579                writeln!(
5580                    self.out,
5581                    "metal::float4 {name}(uint b0, \
5582                                          uint b1, \
5583                                          uint b2, \
5584                                          uint b3) {{"
5585                )?;
5586                writeln!(
5587                    self.out,
5588                    // The following is correct for RGBA packing, but our format seems to
5589                    // match ABGR, which can be fed into the Metal builtin function
5590                    // unpack_unorm10a2_to_float.
5591                    /*
5592                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5593                       uint r = (v & 0xFFC00000) >> 22; \
5594                       uint g = (v & 0x003FF000) >> 12; \
5595                       uint b = (v & 0x00000FFC) >> 2; \
5596                       uint a = (v & 0x00000003); \
5597                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5598                    */
5599                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5600                    back::INDENT
5601                )?;
5602                writeln!(self.out, "}}")?;
5603                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5604            }
5605            Unorm8x4Bgra => {
5606                let name = self.namer.call("unpackUnorm8x4Bgra");
5607                writeln!(
5608                    self.out,
5609                    "metal::float4 {name}(metal::uchar b0, \
5610                                          metal::uchar b1, \
5611                                          metal::uchar b2, \
5612                                          metal::uchar b3) {{"
5613                )?;
5614                writeln!(
5615                    self.out,
5616                    "{}return metal::float4(float(b2) / 255.0f, \
5617                                            float(b1) / 255.0f, \
5618                                            float(b0) / 255.0f, \
5619                                            float(b3) / 255.0f);",
5620                    back::INDENT
5621                )?;
5622                writeln!(self.out, "}}")?;
5623                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5624            }
5625        }
5626    }
5627
5628    fn write_wrapped_unary_op(
5629        &mut self,
5630        module: &crate::Module,
5631        func_ctx: &back::FunctionCtx,
5632        op: crate::UnaryOperator,
5633        operand: Handle<crate::Expression>,
5634    ) -> BackendResult {
5635        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5636        match op {
5637            // Negating the TYPE_MIN of a two's complement signed integer
5638            // type causes overflow, which is undefined behaviour in MSL. To
5639            // avoid this we bitcast the value to unsigned and negate it,
5640            // then bitcast back to signed.
5641            // This adheres to the WGSL spec in that the negative of the
5642            // type's minimum value should equal to the minimum value.
5643            crate::UnaryOperator::Negate
5644                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5645            {
5646                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5647                    return Ok(());
5648                };
5649                let wrapped = WrappedFunction::UnaryOp {
5650                    op,
5651                    ty: (vector_size, scalar),
5652                };
5653                if !self.wrapped_functions.insert(wrapped) {
5654                    return Ok(());
5655                }
5656
5657                let unsigned_scalar = crate::Scalar {
5658                    kind: crate::ScalarKind::Uint,
5659                    ..scalar
5660                };
5661                let mut type_name = String::new();
5662                let mut unsigned_type_name = String::new();
5663                match vector_size {
5664                    None => {
5665                        put_numeric_type(&mut type_name, scalar, &[])?;
5666                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5667                    }
5668                    Some(size) => {
5669                        put_numeric_type(&mut type_name, scalar, &[size])?;
5670                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5671                    }
5672                };
5673
5674                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5675                let level = back::Level(1);
5676                // For sub-32-bit types, C++ integer promotion widens
5677                // `-as_type<ushort>(val)` to `int`, so we need static_cast
5678                // to truncate back before the outer as_type bitcast.
5679                if scalar.width < 4 {
5680                    writeln!(
5681                        self.out,
5682                        "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5683                    )?;
5684                } else {
5685                    writeln!(
5686                        self.out,
5687                        "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5688                    )?;
5689                }
5690                writeln!(self.out, "}}")?;
5691                writeln!(self.out)?;
5692            }
5693            _ => {}
5694        }
5695        Ok(())
5696    }
5697
5698    fn write_wrapped_binary_op(
5699        &mut self,
5700        module: &crate::Module,
5701        func_ctx: &back::FunctionCtx,
5702        expr: Handle<crate::Expression>,
5703        op: crate::BinaryOperator,
5704        left: Handle<crate::Expression>,
5705        right: Handle<crate::Expression>,
5706    ) -> BackendResult {
5707        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5708        let left_ty = func_ctx.resolve_type(left, &module.types);
5709        let right_ty = func_ctx.resolve_type(right, &module.types);
5710        match (op, expr_ty.scalar_kind()) {
5711            // Signed integer division of TYPE_MIN / -1, or signed or
5712            // unsigned division by zero, gives an unspecified value in MSL.
5713            // We override the divisor to 1 in these cases.
5714            // This adheres to the WGSL spec in that:
5715            // * TYPE_MIN / -1 == TYPE_MIN
5716            // * x / 0 == x
5717            (
5718                crate::BinaryOperator::Divide,
5719                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5720            ) => {
5721                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5722                    return Ok(());
5723                };
5724                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5725                    return Ok(());
5726                };
5727                let wrapped = WrappedFunction::BinaryOp {
5728                    op,
5729                    left_ty: left_wrapped_ty,
5730                    right_ty: right_wrapped_ty,
5731                };
5732                if !self.wrapped_functions.insert(wrapped) {
5733                    return Ok(());
5734                }
5735
5736                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5737                    return Ok(());
5738                };
5739                let mut type_name = String::new();
5740                match vector_size {
5741                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5742                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5743                };
5744                writeln!(
5745                    self.out,
5746                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5747                )?;
5748                let level = back::Level(1);
5749                // Sub-32-bit types need typed literal wrappers (e.g. `short(1)`)
5750                // to avoid ambiguous metal::select overloads. For >= 32-bit,
5751                // bare literals like `1`, `-1`, `0` are unambiguous.
5752                let (lp, rp) = if scalar.width < 4 {
5753                    (format!("{type_name}("), ")".to_string())
5754                } else {
5755                    (String::new(), String::new())
5756                };
5757                match scalar.kind {
5758                    crate::ScalarKind::Sint => {
5759                        let min_val = match scalar.width {
5760                            2 => crate::Literal::I16(i16::MIN),
5761                            4 => crate::Literal::I32(i32::MIN),
5762                            8 => crate::Literal::I64(i64::MIN),
5763                            _ => {
5764                                return Err(Error::GenericValidation(format!(
5765                                    "Unexpected width for scalar {scalar:?}"
5766                                )));
5767                            }
5768                        };
5769                        write!(
5770                            self.out,
5771                            "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5772                        )?;
5773                        self.put_literal(min_val)?;
5774                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5775                    }
5776                    crate::ScalarKind::Uint => {
5777                        let suffix = if scalar.width < 4 { "" } else { "u" };
5778                        writeln!(
5779                            self.out,
5780                            "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5781                        )?
5782                    }
5783                    _ => unreachable!(),
5784                }
5785                writeln!(self.out, "}}")?;
5786                writeln!(self.out)?;
5787            }
5788            // Integer modulo where one or both operands are negative, or the
5789            // divisor is zero, is undefined behaviour in MSL. To avoid this
5790            // we use the following equation:
5791            //
5792            // dividend - (dividend / divisor) * divisor
5793            //
5794            // overriding the divisor to 1 if either it is 0, or it is -1
5795            // and the dividend is TYPE_MIN.
5796            //
5797            // This adheres to the WGSL spec in that:
5798            // * TYPE_MIN % -1 == 0
5799            // * x % 0 == 0
5800            (
5801                crate::BinaryOperator::Modulo,
5802                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5803            ) => {
5804                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5805                    return Ok(());
5806                };
5807                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5808                else {
5809                    return Ok(());
5810                };
5811                let wrapped = WrappedFunction::BinaryOp {
5812                    op,
5813                    left_ty: left_wrapped_ty,
5814                    right_ty: (right_vector_size, right_scalar),
5815                };
5816                if !self.wrapped_functions.insert(wrapped) {
5817                    return Ok(());
5818                }
5819
5820                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5821                    return Ok(());
5822                };
5823                let mut type_name = String::new();
5824                match vector_size {
5825                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5826                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5827                };
5828                let mut rhs_type_name = String::new();
5829                match right_vector_size {
5830                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5831                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5832                };
5833
5834                writeln!(
5835                    self.out,
5836                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5837                )?;
5838                let level = back::Level(1);
5839                let (lp, rp) = if scalar.width < 4 {
5840                    (format!("{type_name}("), ")".to_string())
5841                } else {
5842                    (String::new(), String::new())
5843                };
5844                match scalar.kind {
5845                    crate::ScalarKind::Sint => {
5846                        let min_val = match scalar.width {
5847                            2 => crate::Literal::I16(i16::MIN),
5848                            4 => crate::Literal::I32(i32::MIN),
5849                            8 => crate::Literal::I64(i64::MIN),
5850                            _ => {
5851                                return Err(Error::GenericValidation(format!(
5852                                    "Unexpected width for scalar {scalar:?}"
5853                                )));
5854                            }
5855                        };
5856                        write!(
5857                            self.out,
5858                            "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
5859                        )?;
5860                        self.put_literal(min_val)?;
5861                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
5862                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5863                    }
5864                    crate::ScalarKind::Uint => {
5865                        let suffix = if scalar.width < 4 { "" } else { "u" };
5866                        writeln!(
5867                            self.out,
5868                            "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5869                        )?
5870                    }
5871                    _ => unreachable!(),
5872                }
5873                writeln!(self.out, "}}")?;
5874                writeln!(self.out)?;
5875            }
5876            _ => {}
5877        }
5878        Ok(())
5879    }
5880
5881    /// Build the mangled helper name for integer vector dot products.
5882    ///
5883    /// `scalar` must be a concrete integer scalar type.
5884    ///
5885    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5886    fn get_dot_wrapper_function_helper_name(
5887        &self,
5888        scalar: crate::Scalar,
5889        size: crate::VectorSize,
5890    ) -> String {
5891        // Check for consistency with [`super::keywords::RESERVED_SET`]
5892        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5893
5894        let type_name = scalar.to_msl_name();
5895        let size_suffix = common::vector_size_str(size);
5896        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5897    }
5898
5899    #[allow(clippy::too_many_arguments)]
5900    fn write_wrapped_math_function(
5901        &mut self,
5902        module: &crate::Module,
5903        func_ctx: &back::FunctionCtx,
5904        fun: crate::MathFunction,
5905        arg: Handle<crate::Expression>,
5906        _arg1: Option<Handle<crate::Expression>>,
5907        _arg2: Option<Handle<crate::Expression>>,
5908        _arg3: Option<Handle<crate::Expression>>,
5909    ) -> BackendResult {
5910        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5911        match fun {
5912            // Taking the absolute value of the TYPE_MIN of a two's
5913            // complement signed integer type causes overflow, which is
5914            // undefined behaviour in MSL. To avoid this, when the value is
5915            // negative we bitcast the value to unsigned and negate it, then
5916            // bitcast back to signed.
5917            // This adheres to the WGSL spec in that the absolute of the
5918            // type's minimum value should equal to the minimum value.
5919            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5920                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5921                    return Ok(());
5922                };
5923                let wrapped = WrappedFunction::Math {
5924                    fun,
5925                    arg_ty: (vector_size, scalar),
5926                };
5927                if !self.wrapped_functions.insert(wrapped) {
5928                    return Ok(());
5929                }
5930
5931                let unsigned_scalar = crate::Scalar {
5932                    kind: crate::ScalarKind::Uint,
5933                    ..scalar
5934                };
5935                let mut type_name = String::new();
5936                let mut unsigned_type_name = String::new();
5937                match vector_size {
5938                    None => {
5939                        put_numeric_type(&mut type_name, scalar, &[])?;
5940                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5941                    }
5942                    Some(size) => {
5943                        put_numeric_type(&mut type_name, scalar, &[size])?;
5944                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5945                    }
5946                };
5947
5948                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5949                let level = back::Level(1);
5950                let zero = if scalar.width < 4 {
5951                    format!("{type_name}(0)")
5952                } else {
5953                    "0".to_string()
5954                };
5955                let neg_expr = if scalar.width < 4 {
5956                    format!(
5957                        "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
5958                    )
5959                } else {
5960                    format!("-as_type<{unsigned_type_name}>(val)")
5961                };
5962                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
5963                writeln!(self.out, "}}")?;
5964                writeln!(self.out)?;
5965            }
5966
5967            crate::MathFunction::Dot => match *arg_ty {
5968                crate::TypeInner::Vector { size, scalar }
5969                    if matches!(
5970                        scalar.kind,
5971                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
5972                    ) =>
5973                {
5974                    // De-duplicate per (fun, arg type) like other wrapped math functions
5975                    let wrapped = WrappedFunction::Math {
5976                        fun,
5977                        arg_ty: (Some(size), scalar),
5978                    };
5979                    if !self.wrapped_functions.insert(wrapped) {
5980                        return Ok(());
5981                    }
5982
5983                    let mut vec_ty = String::new();
5984                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
5985                    let mut ret_ty = String::new();
5986                    put_numeric_type(&mut ret_ty, scalar, &[])?;
5987
5988                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5989
5990                    // Emit function signature and body using put_dot_product for the expression
5991                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
5992                    let level = back::Level(1);
5993                    write!(self.out, "{level}return ")?;
5994                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
5995                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
5996                        Ok(())
5997                    })?;
5998                    writeln!(self.out, ";")?;
5999                    writeln!(self.out, "}}")?;
6000                    writeln!(self.out)?;
6001                }
6002                _ => {}
6003            },
6004
6005            _ => {}
6006        }
6007        Ok(())
6008    }
6009
6010    fn write_wrapped_cast(
6011        &mut self,
6012        module: &crate::Module,
6013        func_ctx: &back::FunctionCtx,
6014        expr: Handle<crate::Expression>,
6015        kind: crate::ScalarKind,
6016        convert: Option<crate::Bytes>,
6017    ) -> BackendResult {
6018        // Avoid undefined behaviour when casting from a float to integer
6019        // when the value is out of range for the target type. Additionally
6020        // ensure we clamp to the correct value as per the WGSL spec.
6021        //
6022        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
6023        // * If X is exactly representable in the target type T, then the
6024        //   result is that value.
6025        // * Otherwise, the result is the value in T closest to
6026        //   truncate(X) and also exactly representable in the original
6027        //   floating point type.
6028        let src_ty = func_ctx.resolve_type(expr, &module.types);
6029        let Some(width) = convert else {
6030            return Ok(());
6031        };
6032        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6033            return Ok(());
6034        };
6035        let dst_scalar = crate::Scalar { kind, width };
6036        if src_scalar.kind != crate::ScalarKind::Float
6037            || (dst_scalar.kind != crate::ScalarKind::Sint
6038                && dst_scalar.kind != crate::ScalarKind::Uint)
6039        {
6040            return Ok(());
6041        }
6042        let wrapped = WrappedFunction::Cast {
6043            src_scalar,
6044            vector_size,
6045            dst_scalar,
6046        };
6047        if !self.wrapped_functions.insert(wrapped) {
6048            return Ok(());
6049        }
6050        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6051
6052        let mut src_type_name = String::new();
6053        match vector_size {
6054            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6055            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6056        };
6057        let mut dst_type_name = String::new();
6058        match vector_size {
6059            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6060            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6061        };
6062        let fun_name = match dst_scalar {
6063            crate::Scalar::I32 => F2I32_FUNCTION,
6064            crate::Scalar::U32 => F2U32_FUNCTION,
6065            crate::Scalar::I64 => F2I64_FUNCTION,
6066            crate::Scalar::U64 => F2U64_FUNCTION,
6067            _ => unreachable!(),
6068        };
6069
6070        writeln!(
6071            self.out,
6072            "{dst_type_name} {fun_name}({src_type_name} value) {{"
6073        )?;
6074        let level = back::Level(1);
6075        write!(
6076            self.out,
6077            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6078        )?;
6079        self.put_literal(min)?;
6080        write!(self.out, ", ")?;
6081        self.put_literal(max)?;
6082        writeln!(self.out, "));")?;
6083        writeln!(self.out, "}}")?;
6084        writeln!(self.out)?;
6085        Ok(())
6086    }
6087
6088    /// Helper function used by [`Self::write_wrapped_image_load`] and
6089    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
6090    /// conversion code for external textures. Expects the preceding code to
6091    /// declare the Y component as a `float` variable of name `y`, the UV
6092    /// components as a `float2` variable of name `uv`, and the external
6093    /// texture params as a variable of name `params`. The emitted code will
6094    /// return the result.
6095    fn write_convert_yuv_to_rgb_and_return(
6096        &mut self,
6097        level: back::Level,
6098        y: &str,
6099        uv: &str,
6100        params: &str,
6101    ) -> BackendResult {
6102        let l1 = level;
6103        let l2 = l1.next();
6104
6105        // Convert from YUV to non-linear RGB in the source color space.
6106        writeln!(
6107            self.out,
6108            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6109        )?;
6110
6111        // Apply the inverse of the source transfer function to convert to
6112        // linear RGB in the source color space.
6113        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6114        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6115        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6116        writeln!(
6117            self.out,
6118            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6119        )?;
6120
6121        // Multiply by the gamut conversion matrix to convert to linear RGB in
6122        // the destination color space.
6123        writeln!(
6124            self.out,
6125            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6126        )?;
6127
6128        // Finally, apply the dest transfer function to convert to non-linear
6129        // RGB in the destination color space, and return the result.
6130        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6131        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6132        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6133        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6134
6135        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6136        Ok(())
6137    }
6138
6139    #[allow(clippy::too_many_arguments)]
6140    fn write_wrapped_image_load(
6141        &mut self,
6142        module: &crate::Module,
6143        func_ctx: &back::FunctionCtx,
6144        image: Handle<crate::Expression>,
6145        _coordinate: Handle<crate::Expression>,
6146        _array_index: Option<Handle<crate::Expression>>,
6147        _sample: Option<Handle<crate::Expression>>,
6148        _level: Option<Handle<crate::Expression>>,
6149    ) -> BackendResult {
6150        // We currently only need to wrap image loads for external textures
6151        let class = match *func_ctx.resolve_type(image, &module.types) {
6152            crate::TypeInner::Image { class, .. } => class,
6153            _ => unreachable!(),
6154        };
6155        if class != crate::ImageClass::External {
6156            return Ok(());
6157        }
6158        let wrapped = WrappedFunction::ImageLoad { class };
6159        if !self.wrapped_functions.insert(wrapped) {
6160            return Ok(());
6161        }
6162
6163        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6164        let l1 = back::Level(1);
6165        let l2 = l1.next();
6166        let l3 = l2.next();
6167        writeln!(
6168            self.out,
6169            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6170        )?;
6171        // Clamp coords to provided size of external texture to prevent OOB
6172        // read. If params.size is zero then clamp to the actual size of the
6173        // texture.
6174        writeln!(
6175            self.out,
6176            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6177        )?;
6178        writeln!(
6179            self.out,
6180            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6181        )?;
6182
6183        // Apply load transformation
6184        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6185        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6186        // For single plane, simply read from plane0
6187        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6188        writeln!(self.out, "{l1}}} else {{")?;
6189
6190        // Chroma planes may be subsampled so we must scale the coords accordingly.
6191        writeln!(
6192            self.out,
6193            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6194        )?;
6195        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6196
6197        // For multi-plane, read the Y value from plane 0
6198        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6199
6200        writeln!(self.out, "{l2}float2 uv;")?;
6201        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6202        // For 2 planes, read UV from interleaved plane 1
6203        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6204        writeln!(self.out, "{l2}}} else {{")?;
6205        // For 3 planes, read U and V from planes 1 and 2 respectively
6206        writeln!(
6207            self.out,
6208            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6209        )?;
6210        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6211        writeln!(
6212            self.out,
6213            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6214        )?;
6215        writeln!(self.out, "{l2}}}")?;
6216
6217        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6218
6219        writeln!(self.out, "{l1}}}")?;
6220        writeln!(self.out, "}}")?;
6221        writeln!(self.out)?;
6222        Ok(())
6223    }
6224
6225    #[allow(clippy::too_many_arguments)]
6226    fn write_wrapped_image_sample(
6227        &mut self,
6228        module: &crate::Module,
6229        func_ctx: &back::FunctionCtx,
6230        image: Handle<crate::Expression>,
6231        _sampler: Handle<crate::Expression>,
6232        _gather: Option<crate::SwizzleComponent>,
6233        _coordinate: Handle<crate::Expression>,
6234        _array_index: Option<Handle<crate::Expression>>,
6235        _offset: Option<Handle<crate::Expression>>,
6236        _level: crate::SampleLevel,
6237        _depth_ref: Option<Handle<crate::Expression>>,
6238        clamp_to_edge: bool,
6239    ) -> BackendResult {
6240        // We currently only need to wrap textureSampleBaseClampToEdge, for
6241        // both sampled and external textures.
6242        if !clamp_to_edge {
6243            return Ok(());
6244        }
6245        let class = match *func_ctx.resolve_type(image, &module.types) {
6246            crate::TypeInner::Image { class, .. } => class,
6247            _ => unreachable!(),
6248        };
6249        let wrapped = WrappedFunction::ImageSample {
6250            class,
6251            clamp_to_edge: true,
6252        };
6253        if !self.wrapped_functions.insert(wrapped) {
6254            return Ok(());
6255        }
6256        match class {
6257            crate::ImageClass::External => {
6258                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6259                let l1 = back::Level(1);
6260                let l2 = l1.next();
6261                let l3 = l2.next();
6262                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6263                writeln!(
6264                    self.out,
6265                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6266                )?;
6267
6268                // Calculate the sample bounds. The purported size of the texture
6269                // (params.size) is irrelevant here as we are dealing with normalized
6270                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6271                // apply the sample transformation to that, also bearing in mind that it
6272                // may contain a flip on either axis. We calculate and adjust for the
6273                // half-texel separately for each plane as it depends on the actual
6274                // texture size which may vary between planes.
6275                writeln!(
6276                    self.out,
6277                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6278                )?;
6279                writeln!(
6280                    self.out,
6281                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6282                )?;
6283                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6284                writeln!(
6285                    self.out,
6286                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6287                )?;
6288                writeln!(
6289                    self.out,
6290                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6291                )?;
6292                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6293                // For single plane, simply sample from plane0
6294                writeln!(
6295                    self.out,
6296                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6297                )?;
6298                writeln!(self.out, "{l1}}} else {{")?;
6299                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6300                writeln!(
6301                    self.out,
6302                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6303                )?;
6304                writeln!(
6305                    self.out,
6306                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6307                )?;
6308
6309                // For multi-plane, sample the Y value from plane 0
6310                writeln!(
6311                    self.out,
6312                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6313                )?;
6314                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6315                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6316                // For 2 planes, sample UV from interleaved plane 1
6317                writeln!(
6318                    self.out,
6319                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6320                )?;
6321                writeln!(self.out, "{l2}}} else {{")?;
6322                // For 3 planes, sample U and V from planes 1 and 2 respectively
6323                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6324                writeln!(
6325                    self.out,
6326                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6327                )?;
6328                writeln!(
6329                    self.out,
6330                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6331                )?;
6332                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6333                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6334                writeln!(self.out, "{l2}}}")?;
6335
6336                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6337
6338                writeln!(self.out, "{l1}}}")?;
6339                writeln!(self.out, "}}")?;
6340                writeln!(self.out)?;
6341            }
6342            _ => {
6343                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6344                let l1 = back::Level(1);
6345                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6346                writeln!(
6347                    self.out,
6348                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6349                )?;
6350                writeln!(self.out, "}}")?;
6351                writeln!(self.out)?;
6352            }
6353        }
6354        Ok(())
6355    }
6356
6357    fn write_wrapped_image_query(
6358        &mut self,
6359        module: &crate::Module,
6360        func_ctx: &back::FunctionCtx,
6361        image: Handle<crate::Expression>,
6362        query: crate::ImageQuery,
6363    ) -> BackendResult {
6364        // We currently only need to wrap size image queries for external textures
6365        if !matches!(query, crate::ImageQuery::Size { .. }) {
6366            return Ok(());
6367        }
6368        let class = match *func_ctx.resolve_type(image, &module.types) {
6369            crate::TypeInner::Image { class, .. } => class,
6370            _ => unreachable!(),
6371        };
6372        if class != crate::ImageClass::External {
6373            return Ok(());
6374        }
6375        let wrapped = WrappedFunction::ImageQuerySize { class };
6376        if !self.wrapped_functions.insert(wrapped) {
6377            return Ok(());
6378        }
6379        writeln!(
6380            self.out,
6381            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6382        )?;
6383        let l1 = back::Level(1);
6384        let l2 = l1.next();
6385        writeln!(
6386            self.out,
6387            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6388        )?;
6389        writeln!(self.out, "{l2}return tex.params.size;")?;
6390        writeln!(self.out, "{l1}}} else {{")?;
6391        // params.size == (0, 0) indicates to query and return plane 0's actual size
6392        writeln!(
6393            self.out,
6394            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6395        )?;
6396        writeln!(self.out, "{l1}}}")?;
6397        writeln!(self.out, "}}")?;
6398        writeln!(self.out)?;
6399        Ok(())
6400    }
6401
6402    fn write_wrapped_cooperative_load(
6403        &mut self,
6404        module: &crate::Module,
6405        func_ctx: &back::FunctionCtx,
6406        columns: crate::CooperativeSize,
6407        rows: crate::CooperativeSize,
6408        pointer: Handle<crate::Expression>,
6409    ) -> BackendResult {
6410        let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6411        let space = ptr_ty.pointer_space().unwrap();
6412        let space_name = space.to_msl_name().unwrap_or_default();
6413        let scalar = ptr_ty
6414            .pointer_base_type()
6415            .unwrap()
6416            .inner_with(&module.types)
6417            .scalar()
6418            .unwrap();
6419        let wrapped = WrappedFunction::CooperativeLoad {
6420            space_name,
6421            columns,
6422            rows,
6423            scalar,
6424        };
6425        if !self.wrapped_functions.insert(wrapped) {
6426            return Ok(());
6427        }
6428        let scalar_name = scalar.to_msl_name();
6429        writeln!(
6430            self.out,
6431            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6432            columns as u32, rows as u32,
6433        )?;
6434        let l1 = back::Level(1);
6435        writeln!(
6436            self.out,
6437            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6438            columns as u32, rows as u32
6439        )?;
6440        let matrix_origin = "0";
6441        writeln!(
6442            self.out,
6443            "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6444        )?;
6445        writeln!(self.out, "{l1}return m;")?;
6446        writeln!(self.out, "}}")?;
6447        writeln!(self.out)?;
6448        Ok(())
6449    }
6450
6451    fn write_wrapped_cooperative_multiply_add(
6452        &mut self,
6453        module: &crate::Module,
6454        func_ctx: &back::FunctionCtx,
6455        space: crate::AddressSpace,
6456        a: Handle<crate::Expression>,
6457        b: Handle<crate::Expression>,
6458    ) -> BackendResult {
6459        let space_name = space.to_msl_name().unwrap_or_default();
6460        let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6461            crate::TypeInner::CooperativeMatrix {
6462                columns,
6463                rows,
6464                scalar,
6465                ..
6466            } => (columns, rows, scalar),
6467            _ => unreachable!(),
6468        };
6469        let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6470            crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6471            _ => unreachable!(),
6472        };
6473        let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6474            space_name,
6475            columns: b_c,
6476            rows: a_r,
6477            intermediate: a_c,
6478            scalar,
6479        };
6480        if !self.wrapped_functions.insert(wrapped) {
6481            return Ok(());
6482        }
6483        let scalar_name = scalar.to_msl_name();
6484        writeln!(
6485            self.out,
6486            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6487            b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6488        )?;
6489        let l1 = back::Level(1);
6490        writeln!(
6491            self.out,
6492            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6493            b_c as u32, a_r as u32
6494        )?;
6495        writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6496        writeln!(self.out, "{l1}return d;")?;
6497        writeln!(self.out, "}}")?;
6498        writeln!(self.out)?;
6499        Ok(())
6500    }
6501
6502    pub(super) fn write_wrapped_functions(
6503        &mut self,
6504        module: &crate::Module,
6505        func_ctx: &back::FunctionCtx,
6506    ) -> BackendResult {
6507        for (expr_handle, expr) in func_ctx.expressions.iter() {
6508            match *expr {
6509                crate::Expression::Unary { op, expr: operand } => {
6510                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6511                }
6512                crate::Expression::Binary { op, left, right } => {
6513                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6514                }
6515                crate::Expression::Math {
6516                    fun,
6517                    arg,
6518                    arg1,
6519                    arg2,
6520                    arg3,
6521                } => {
6522                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6523                }
6524                crate::Expression::As {
6525                    expr,
6526                    kind,
6527                    convert,
6528                } => {
6529                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6530                }
6531                crate::Expression::ImageLoad {
6532                    image,
6533                    coordinate,
6534                    array_index,
6535                    sample,
6536                    level,
6537                } => {
6538                    self.write_wrapped_image_load(
6539                        module,
6540                        func_ctx,
6541                        image,
6542                        coordinate,
6543                        array_index,
6544                        sample,
6545                        level,
6546                    )?;
6547                }
6548                crate::Expression::ImageSample {
6549                    image,
6550                    sampler,
6551                    gather,
6552                    coordinate,
6553                    array_index,
6554                    offset,
6555                    level,
6556                    depth_ref,
6557                    clamp_to_edge,
6558                } => {
6559                    self.write_wrapped_image_sample(
6560                        module,
6561                        func_ctx,
6562                        image,
6563                        sampler,
6564                        gather,
6565                        coordinate,
6566                        array_index,
6567                        offset,
6568                        level,
6569                        depth_ref,
6570                        clamp_to_edge,
6571                    )?;
6572                }
6573                crate::Expression::ImageQuery { image, query } => {
6574                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6575                }
6576                crate::Expression::CooperativeLoad {
6577                    columns,
6578                    rows,
6579                    role: _,
6580                    ref data,
6581                } => {
6582                    self.write_wrapped_cooperative_load(
6583                        module,
6584                        func_ctx,
6585                        columns,
6586                        rows,
6587                        data.pointer,
6588                    )?;
6589                }
6590                crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6591                    let space = crate::AddressSpace::Private;
6592                    self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6593                }
6594                crate::Expression::RayQueryGetIntersection { committed, .. } => {
6595                    self.write_rq_get_intersection_function(module, committed)?;
6596                }
6597                _ => {}
6598            }
6599        }
6600
6601        Ok(())
6602    }
6603
6604    // Returns the array of mapped entry point names.
6605    fn write_functions(
6606        &mut self,
6607        module: &crate::Module,
6608        mod_info: &valid::ModuleInfo,
6609        options: &Options,
6610        pipeline_options: &PipelineOptions,
6611    ) -> Result<TranslationInfo, Error> {
6612        use back::msl::VertexFormat;
6613
6614        // Define structs to hold resolved/generated data for vertex buffers and
6615        // their attributes.
6616        struct AttributeMappingResolved {
6617            ty_name: String,
6618            dimension: Option<crate::VectorSize>,
6619            scalar: crate::Scalar,
6620            name: String,
6621        }
6622        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6623
6624        struct VertexBufferMappingResolved<'a> {
6625            id: u32,
6626            stride: u32,
6627            step_mode: back::msl::VertexBufferStepMode,
6628            ty_name: String,
6629            param_name: String,
6630            elem_name: String,
6631            attributes: &'a Vec<back::msl::AttributeMapping>,
6632        }
6633        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6634
6635        // Define a struct to hold a named reference to a byte-unpacking function.
6636        struct UnpackingFunction {
6637            name: String,
6638            byte_count: u32,
6639            dimension: Option<crate::VectorSize>,
6640            scalar: crate::Scalar,
6641        }
6642        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6643
6644        // Check if we are attempting vertex pulling. If we are, generate some
6645        // names we'll need, and iterate the vertex buffer mappings to output
6646        // all the conversion functions we'll need to unpack the attribute data.
6647        // We can re-use these names for all entry points that need them, since
6648        // those entry points also use self.namer.
6649        let mut needs_vertex_id = false;
6650        let v_id = self.namer.call("v_id");
6651
6652        let mut needs_instance_id = false;
6653        let i_id = self.namer.call("i_id");
6654        if pipeline_options.vertex_pulling_transform {
6655            for vbm in &pipeline_options.vertex_buffer_mappings {
6656                let buffer_id = vbm.id;
6657                let buffer_stride = vbm.stride;
6658
6659                assert!(
6660                    buffer_stride > 0,
6661                    "Vertex pulling requires a non-zero buffer stride."
6662                );
6663
6664                match vbm.step_mode {
6665                    back::msl::VertexBufferStepMode::Constant => {}
6666                    back::msl::VertexBufferStepMode::ByVertex => {
6667                        needs_vertex_id = true;
6668                    }
6669                    back::msl::VertexBufferStepMode::ByInstance => {
6670                        needs_instance_id = true;
6671                    }
6672                }
6673
6674                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6675                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6676                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6677
6678                vbm_resolved.push(VertexBufferMappingResolved {
6679                    id: buffer_id,
6680                    stride: buffer_stride,
6681                    step_mode: vbm.step_mode,
6682                    ty_name: buffer_ty,
6683                    param_name: buffer_param,
6684                    elem_name: buffer_elem,
6685                    attributes: &vbm.attributes,
6686                });
6687
6688                // Iterate the attributes and generate needed unpacking functions.
6689                for attribute in &vbm.attributes {
6690                    if unpacking_functions.contains_key(&attribute.format) {
6691                        continue;
6692                    }
6693                    let (name, byte_count, dimension, scalar) =
6694                        match self.write_unpacking_function(attribute.format) {
6695                            Ok((name, byte_count, dimension, scalar)) => {
6696                                (name, byte_count, dimension, scalar)
6697                            }
6698                            _ => {
6699                                continue;
6700                            }
6701                        };
6702                    unpacking_functions.insert(
6703                        attribute.format,
6704                        UnpackingFunction {
6705                            name,
6706                            byte_count,
6707                            dimension,
6708                            scalar,
6709                        },
6710                    );
6711                }
6712            }
6713        }
6714
6715        let mut pass_through_globals = Vec::new();
6716        for (fun_handle, fun) in module.functions.iter() {
6717            log::trace!(
6718                "function {:?}, handle {:?}",
6719                fun.name.as_deref().unwrap_or("(anonymous)"),
6720                fun_handle
6721            );
6722
6723            let ctx = back::FunctionCtx {
6724                ty: back::FunctionType::Function(fun_handle),
6725                info: &mod_info[fun_handle],
6726                expressions: &fun.expressions,
6727                named_expressions: &fun.named_expressions,
6728            };
6729
6730            writeln!(self.out)?;
6731            self.write_wrapped_functions(module, &ctx)?;
6732
6733            let fun_info = &mod_info[fun_handle];
6734            pass_through_globals.clear();
6735            let mut needs_buffer_sizes = false;
6736            for (handle, var) in module.global_variables.iter() {
6737                if !fun_info[handle].is_empty() {
6738                    if var.space.needs_pass_through() {
6739                        pass_through_globals.push(handle);
6740                    }
6741                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6742                }
6743            }
6744
6745            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6746            match fun.result {
6747                Some(ref result) => {
6748                    let ty_name = TypeContext {
6749                        handle: result.ty,
6750                        gctx: module.to_ctx(),
6751                        names: &self.names,
6752                        access: crate::StorageAccess::empty(),
6753                        first_time: false,
6754                    };
6755                    write!(self.out, "{ty_name}")?;
6756                }
6757                None => {
6758                    write!(self.out, "void")?;
6759                }
6760            }
6761            writeln!(self.out, " {fun_name}(")?;
6762
6763            for (index, arg) in fun.arguments.iter().enumerate() {
6764                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6765                let param_type_name = TypeContext {
6766                    handle: arg.ty,
6767                    gctx: module.to_ctx(),
6768                    names: &self.names,
6769                    access: crate::StorageAccess::empty(),
6770                    first_time: false,
6771                };
6772                let separator = separate(
6773                    !pass_through_globals.is_empty()
6774                        || index + 1 != fun.arguments.len()
6775                        || needs_buffer_sizes,
6776                );
6777                writeln!(
6778                    self.out,
6779                    "{}{} {}{}",
6780                    back::INDENT,
6781                    param_type_name,
6782                    name,
6783                    separator
6784                )?;
6785            }
6786            for (index, &handle) in pass_through_globals.iter().enumerate() {
6787                let tyvar = TypedGlobalVariable {
6788                    module,
6789                    names: &self.names,
6790                    handle,
6791                    usage: fun_info[handle],
6792                    reference: true,
6793                };
6794                let separator =
6795                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6796                write!(self.out, "{}", back::INDENT)?;
6797                tyvar.try_fmt(&mut self.out)?;
6798                writeln!(self.out, "{separator}")?;
6799            }
6800
6801            if needs_buffer_sizes {
6802                writeln!(
6803                    self.out,
6804                    "{}constant _mslBufferSizes& _buffer_sizes",
6805                    back::INDENT
6806                )?;
6807            }
6808
6809            writeln!(self.out, ") {{")?;
6810
6811            let guarded_indices =
6812                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6813
6814            let context = StatementContext {
6815                expression: ExpressionContext {
6816                    function: fun,
6817                    origin: FunctionOrigin::Handle(fun_handle),
6818                    info: fun_info,
6819                    lang_version: options.lang_version,
6820                    policies: options.bounds_check_policies,
6821                    guarded_indices,
6822                    module,
6823                    mod_info,
6824                    pipeline_options,
6825                    force_loop_bounding: options.force_loop_bounding,
6826                },
6827                result_struct: None,
6828            };
6829
6830            self.put_locals(&context.expression)?;
6831            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6832            self.put_block(back::Level(1), &fun.body, &context)?;
6833            writeln!(self.out, "}}")?;
6834            self.named_expressions.clear();
6835        }
6836
6837        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6838            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6839
6840        let mut info = TranslationInfo {
6841            entry_point_names: Vec::with_capacity(ep_range.len()),
6842        };
6843
6844        for ep_index in ep_range {
6845            let ep = &module.entry_points[ep_index];
6846            let fun = &ep.function;
6847            let fun_info = mod_info.get_entry_point(ep_index);
6848            let mut ep_error = None;
6849
6850            // For vertex_id and instance_id arguments, presume that we'll
6851            // use our generated names, but switch to the name of an
6852            // existing @builtin param, if we find one.
6853            let mut v_existing_id = None;
6854            let mut i_existing_id = None;
6855
6856            log::trace!(
6857                "entry point {:?}, index {:?}",
6858                fun.name.as_deref().unwrap_or("(anonymous)"),
6859                ep_index
6860            );
6861
6862            let ctx = back::FunctionCtx {
6863                ty: back::FunctionType::EntryPoint(ep_index as u16),
6864                info: fun_info,
6865                expressions: &fun.expressions,
6866                named_expressions: &fun.named_expressions,
6867            };
6868
6869            self.write_wrapped_functions(module, &ctx)?;
6870
6871            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6872                crate::ShaderStage::Vertex => (
6873                    Some("vertex"),
6874                    LocationMode::VertexInput,
6875                    LocationMode::VertexOutput,
6876                    true,
6877                ),
6878                crate::ShaderStage::Fragment => (
6879                    Some("fragment"),
6880                    LocationMode::FragmentInput,
6881                    LocationMode::FragmentOutput,
6882                    false,
6883                ),
6884                crate::ShaderStage::Compute => (
6885                    Some("kernel"),
6886                    LocationMode::Uniform,
6887                    LocationMode::Uniform,
6888                    false,
6889                ),
6890                crate::ShaderStage::Task => {
6891                    (None, LocationMode::Uniform, LocationMode::Uniform, false)
6892                }
6893                crate::ShaderStage::Mesh => {
6894                    (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6895                }
6896                crate::ShaderStage::RayGeneration
6897                | crate::ShaderStage::AnyHit
6898                | crate::ShaderStage::ClosestHit
6899                | crate::ShaderStage::Miss => unimplemented!(),
6900            };
6901
6902            // Should this entry point be modified to do vertex pulling?
6903            let do_vertex_pulling = can_vertex_pull
6904                && pipeline_options.vertex_pulling_transform
6905                && !pipeline_options.vertex_buffer_mappings.is_empty();
6906
6907            // Is any global variable used by this entry point dynamically sized?
6908            let needs_buffer_sizes = do_vertex_pulling
6909                || module
6910                    .global_variables
6911                    .iter()
6912                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6913                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6914
6915            // skip this entry point if any global bindings are missing,
6916            // or their types are incompatible.
6917            if !options.fake_missing_bindings {
6918                for (var_handle, var) in module.global_variables.iter() {
6919                    if fun_info[var_handle].is_empty() {
6920                        continue;
6921                    }
6922                    match var.space {
6923                        crate::AddressSpace::Uniform
6924                        | crate::AddressSpace::Storage { .. }
6925                        | crate::AddressSpace::Handle => {
6926                            let br = match var.binding {
6927                                Some(ref br) => br,
6928                                None => {
6929                                    let var_name = var.name.clone().unwrap_or_default();
6930                                    ep_error =
6931                                        Some(super::EntryPointError::MissingBinding(var_name));
6932                                    break;
6933                                }
6934                            };
6935                            let target = options.get_resource_binding_target(ep, br);
6936                            let good = match target {
6937                                Some(target) => {
6938                                    // We intentionally don't dereference binding_arrays here,
6939                                    // so that binding arrays fall to the buffer location.
6940
6941                                    match module.types[var.ty].inner {
6942                                        crate::TypeInner::Image {
6943                                            class: crate::ImageClass::External,
6944                                            ..
6945                                        } => target.external_texture.is_some(),
6946                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6947                                        crate::TypeInner::Sampler { .. } => {
6948                                            target.sampler.is_some()
6949                                        }
6950                                        _ => target.buffer.is_some(),
6951                                    }
6952                                }
6953                                None => false,
6954                            };
6955                            if !good {
6956                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6957                                break;
6958                            }
6959                        }
6960                        crate::AddressSpace::Immediate => {
6961                            if let Err(e) = options.resolve_immediates(ep) {
6962                                ep_error = Some(e);
6963                                break;
6964                            }
6965                        }
6966                        crate::AddressSpace::Function
6967                        | crate::AddressSpace::Private
6968                        | crate::AddressSpace::WorkGroup
6969                        | crate::AddressSpace::TaskPayload => {}
6970                        crate::AddressSpace::RayPayload
6971                        | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6972                    }
6973                }
6974                if needs_buffer_sizes {
6975                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6976                        ep_error = Some(err);
6977                    }
6978                }
6979            }
6980
6981            if let Some(err) = ep_error {
6982                info.entry_point_names.push(Err(err));
6983                continue;
6984            }
6985            let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
6986            info.entry_point_names.push(Ok(fun_name.clone()));
6987
6988            writeln!(self.out)?;
6989
6990            // Since `Namer.reset` wasn't expecting struct members to be
6991            // suddenly injected into another namespace like this,
6992            // `self.names` doesn't keep them distinct from other variables.
6993            // Generate fresh names for these arguments, and remember the
6994            // mapping.
6995            let mut flattened_member_names = FastHashMap::default();
6996            // Varyings' members get their own namespace
6997            let mut varyings_namer = proc::Namer::default();
6998
6999            let mut empty_names = FastHashMap::default(); // Create a throwaway map
7000            varyings_namer.reset(
7001                module,
7002                &super::keywords::RESERVED_SET,
7003                proc::KeywordSet::empty(),
7004                proc::CaseInsensitiveKeywordSet::empty(),
7005                &[CLAMPED_LOD_LOAD_PREFIX],
7006                &mut empty_names,
7007            );
7008
7009            // List all the Naga `EntryPoint`'s `Function`'s arguments,
7010            // flattening structs into their members. In Metal, we will pass
7011            // each of these values to the entry point as a separate argument—
7012            // except for the varyings, handled next.
7013            let mut flattened_arguments = Vec::new();
7014            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7015                match module.types[arg.ty].inner {
7016                    crate::TypeInner::Struct { ref members, .. } => {
7017                        for (member_index, member) in members.iter().enumerate() {
7018                            let member_index = member_index as u32;
7019                            flattened_arguments.push((
7020                                NameKey::StructMember(arg.ty, member_index),
7021                                member.ty,
7022                                member.binding.as_ref(),
7023                            ));
7024                            let name_key = NameKey::StructMember(arg.ty, member_index);
7025                            let name = match member.binding {
7026                                Some(crate::Binding::Location { .. }) => {
7027                                    if do_vertex_pulling {
7028                                        self.namer.call(&self.names[&name_key])
7029                                    } else {
7030                                        varyings_namer.call(&self.names[&name_key])
7031                                    }
7032                                }
7033                                _ => self.namer.call(&self.names[&name_key]),
7034                            };
7035                            flattened_member_names.insert(name_key, name);
7036                        }
7037                    }
7038                    _ => flattened_arguments.push((
7039                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7040                        arg.ty,
7041                        arg.binding.as_ref(),
7042                    )),
7043                }
7044            }
7045
7046            // Identify the varyings among the argument values, and maybe emit
7047            // a struct type named `<fun>Input` to hold them. If we are doing
7048            // vertex pulling, we instead update our attribute mapping to
7049            // note the types, names, and zero values of the attributes.
7050            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7051            let varyings_member_name = self.namer.call("varyings");
7052            let mut has_varyings = false;
7053
7054            if !flattened_arguments.is_empty() {
7055                if !do_vertex_pulling {
7056                    writeln!(self.out, "struct {stage_in_name} {{")?;
7057                }
7058                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7059                    let Some(binding) = binding else {
7060                        continue;
7061                    };
7062                    let name = match *name_key {
7063                        NameKey::StructMember(..) => &flattened_member_names[name_key],
7064                        _ => &self.names[name_key],
7065                    };
7066                    let ty_name = TypeContext {
7067                        handle: ty,
7068                        gctx: module.to_ctx(),
7069                        names: &self.names,
7070                        access: crate::StorageAccess::empty(),
7071                        first_time: false,
7072                    };
7073                    let resolved = options.resolve_local_binding(binding, in_mode)?;
7074                    let location = match *binding {
7075                        crate::Binding::Location { location, .. } => Some(location),
7076                        crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7077                        crate::Binding::BuiltIn(_) => continue,
7078                    };
7079                    if do_vertex_pulling {
7080                        let Some(location) = location else {
7081                            continue;
7082                        };
7083                        // Update our attribute mapping.
7084                        am_resolved.insert(
7085                            location,
7086                            AttributeMappingResolved {
7087                                ty_name: ty_name.to_string(),
7088                                dimension: ty_name.vector_size(),
7089                                scalar: ty_name.scalar().unwrap(),
7090                                name: name.to_string(),
7091                            },
7092                        );
7093                    } else {
7094                        has_varyings = true;
7095                        if let super::ResolvedBinding::User {
7096                            prefix,
7097                            index,
7098                            interpolation: Some(super::ResolvedInterpolation::PerVertex),
7099                        } = resolved
7100                        {
7101                            if options.lang_version < (4, 0) {
7102                                return Err(Error::PerVertexNotSupported);
7103                            }
7104                            write!(
7105                                self.out,
7106                                "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7107                                back::INDENT,
7108                                ty_name.unwrap_array()
7109                            )?;
7110                        } else {
7111                            write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7112                            resolved.try_fmt(&mut self.out)?;
7113                        }
7114                        writeln!(self.out, ";")?;
7115                    }
7116                }
7117                if !do_vertex_pulling {
7118                    writeln!(self.out, "}};")?;
7119                }
7120            }
7121
7122            // Define a struct type to hold the return value, if any, named
7123            // `<fun>Output`.
7124            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7125            let result_member_name = self.namer.call("member");
7126            let result_type_name = match fun.result {
7127                Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7128                    let mut result_members = Vec::new();
7129                    if let crate::TypeInner::Struct { ref members, .. } =
7130                        module.types[result.ty].inner
7131                    {
7132                        for (member_index, member) in members.iter().enumerate() {
7133                            result_members.push((
7134                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7135                                member.ty,
7136                                member.binding.as_ref(),
7137                            ));
7138                        }
7139                    } else {
7140                        result_members.push((
7141                            &result_member_name,
7142                            result.ty,
7143                            result.binding.as_ref(),
7144                        ));
7145                    }
7146
7147                    writeln!(self.out, "struct {stage_out_name} {{")?;
7148                    let mut has_point_size = false;
7149                    for (name, ty, binding) in result_members {
7150                        let ty_name = TypeContext {
7151                            handle: ty,
7152                            gctx: module.to_ctx(),
7153                            names: &self.names,
7154                            access: crate::StorageAccess::empty(),
7155                            first_time: true,
7156                        };
7157                        let binding = binding.ok_or_else(|| {
7158                            Error::GenericValidation("Expected binding, got None".into())
7159                        })?;
7160
7161                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7162                            has_point_size = true;
7163                            if !pipeline_options.allow_and_force_point_size {
7164                                continue;
7165                            }
7166                        }
7167
7168                        let array_len = match module.types[ty].inner {
7169                            crate::TypeInner::Array {
7170                                size: crate::ArraySize::Constant(size),
7171                                ..
7172                            } => Some(size),
7173                            _ => None,
7174                        };
7175                        let resolved = options.resolve_local_binding(binding, out_mode)?;
7176                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7177                        resolved.try_fmt(&mut self.out)?;
7178                        if let Some(array_len) = array_len {
7179                            write!(self.out, " [{array_len}]")?;
7180                        }
7181                        writeln!(self.out, ";")?;
7182                    }
7183
7184                    if pipeline_options.allow_and_force_point_size
7185                        && ep.stage == crate::ShaderStage::Vertex
7186                        && !has_point_size
7187                    {
7188                        // inject the point size output last
7189                        writeln!(
7190                            self.out,
7191                            "{}float _point_size [[point_size]];",
7192                            back::INDENT
7193                        )?;
7194                    }
7195                    writeln!(self.out, "}};")?;
7196                    &stage_out_name
7197                }
7198                Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7199                    assert_eq!(
7200                        module.types[result.ty].inner,
7201                        crate::TypeInner::Vector {
7202                            size: crate::VectorSize::Tri,
7203                            scalar: crate::Scalar::U32
7204                        }
7205                    );
7206
7207                    "metal::uint3"
7208                }
7209                _ => "void",
7210            };
7211
7212            let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7213                Some(self.write_mesh_output_types(
7214                    mesh_info,
7215                    &fun_name,
7216                    module,
7217                    pipeline_options.allow_and_force_point_size,
7218                    options,
7219                )?)
7220            } else {
7221                None
7222            };
7223
7224            // If we're doing a vertex pulling transform, define the buffer
7225            // structure types.
7226            if do_vertex_pulling {
7227                for vbm in &vbm_resolved {
7228                    let buffer_stride = vbm.stride;
7229                    let buffer_ty = &vbm.ty_name;
7230
7231                    // Define a structure of bytes of the appropriate size.
7232                    // When we access the attributes, we'll be unpacking these
7233                    // bytes at some offset.
7234                    writeln!(
7235                        self.out,
7236                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7237                    )?;
7238                }
7239            }
7240
7241            let is_wrapped = matches!(
7242                ep.stage,
7243                crate::ShaderStage::Task | crate::ShaderStage::Mesh
7244            );
7245            let fun_name = fun_name.clone();
7246            let nested_fun_name = if is_wrapped {
7247                self.namer.call(&format!("_{fun_name}"))
7248            } else {
7249                fun_name.clone()
7250            };
7251
7252            // https://web.archive.org/web/20181029003926/https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
7253            if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7254                let total_threads =
7255                    ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7256                write!(
7257                    self.out,
7258                    "[[max_total_threads_per_threadgroup({total_threads})]] "
7259                )?;
7260            }
7261
7262            // Write the entry point function's name, and begin its argument list.
7263            if let Some(em_str) = em_str {
7264                write!(self.out, "{em_str} ")?;
7265            }
7266            writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7267
7268            let mut args = Vec::new();
7269
7270            // If we have produced a struct holding the `EntryPoint`'s
7271            // `Function`'s arguments' varyings, pass that struct first.
7272            if has_varyings {
7273                args.push(EntryPointArgument {
7274                    ty_name: stage_in_name,
7275                    name: varyings_member_name.clone(),
7276                    binding: " [[stage_in]]".to_string(),
7277                    init: None,
7278                });
7279            }
7280
7281            let mut local_invocation_index = None;
7282
7283            // Then pass the remaining arguments not included in the varyings
7284            // struct.
7285            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7286                let binding = match binding {
7287                    Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7288                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7289                    _ => continue,
7290                };
7291                let name = match *name_key {
7292                    NameKey::StructMember(..) => &flattened_member_names[name_key],
7293                    _ => &self.names[name_key],
7294                };
7295
7296                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7297                    local_invocation_index = Some(name_key);
7298                }
7299
7300                let ty_name = TypeContext {
7301                    handle: ty,
7302                    gctx: module.to_ctx(),
7303                    names: &self.names,
7304                    access: crate::StorageAccess::empty(),
7305                    first_time: false,
7306                };
7307
7308                match *binding {
7309                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7310                        v_existing_id = Some(name.clone());
7311                    }
7312                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7313                        i_existing_id = Some(name.clone());
7314                    }
7315                    _ => {}
7316                };
7317
7318                let resolved = options.resolve_local_binding(binding, in_mode)?;
7319                let mut binding = String::new();
7320                resolved.try_fmt(&mut binding)?;
7321
7322                args.push(EntryPointArgument {
7323                    ty_name: format!("{ty_name}"),
7324                    name: name.clone(),
7325                    binding,
7326                    init: None,
7327                });
7328            }
7329
7330            let need_workgroup_variables_initialization =
7331                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7332
7333            if local_invocation_index.is_none()
7334                && (need_workgroup_variables_initialization
7335                    || ep.stage == crate::ShaderStage::Task
7336                    || ep.stage == crate::ShaderStage::Mesh)
7337            {
7338                args.push(EntryPointArgument {
7339                    ty_name: "uint".to_string(),
7340                    name: "__local_invocation_index".to_string(),
7341                    binding: " [[thread_index_in_threadgroup]]".to_string(),
7342                    init: None,
7343                });
7344            }
7345
7346            // Those global variables used by this entry point and its callees
7347            // get passed as arguments. `Private` globals are an exception, they
7348            // don't outlive this invocation, so we declare them below as locals
7349            // within the entry point.
7350            for (handle, var) in module.global_variables.iter() {
7351                let usage = fun_info[handle];
7352                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7353                    continue;
7354                }
7355
7356                if options.lang_version < (1, 2) {
7357                    match var.space {
7358                        // This restriction is not documented in the MSL spec
7359                        // but validation will fail if it is not upheld.
7360                        //
7361                        // We infer the required version from the "Function
7362                        // Buffer Read-Writes" section of [what's new], where
7363                        // the feature sets listed correspond with the ones
7364                        // supporting MSL 1.2.
7365                        //
7366                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7367                        crate::AddressSpace::Storage { access }
7368                            if access.contains(crate::StorageAccess::STORE)
7369                                && ep.stage == crate::ShaderStage::Fragment =>
7370                        {
7371                            return Err(Error::UnsupportedWritableStorageBuffer)
7372                        }
7373                        crate::AddressSpace::Handle => {
7374                            match module.types[var.ty].inner {
7375                                crate::TypeInner::Image {
7376                                    class: crate::ImageClass::Storage { access, .. },
7377                                    ..
7378                                } => {
7379                                    // This restriction is not documented in the MSL spec
7380                                    // but validation will fail if it is not upheld.
7381                                    //
7382                                    // We infer the required version from the "Function
7383                                    // Texture Read-Writes" section of [what's new], where
7384                                    // the feature sets listed correspond with the ones
7385                                    // supporting MSL 1.2.
7386                                    //
7387                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7388                                    if access.contains(crate::StorageAccess::STORE)
7389                                        && (ep.stage == crate::ShaderStage::Vertex
7390                                            || ep.stage == crate::ShaderStage::Fragment)
7391                                    {
7392                                        return Err(Error::UnsupportedWritableStorageTexture(
7393                                            ep.stage,
7394                                        ));
7395                                    }
7396
7397                                    if access.contains(
7398                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7399                                    ) {
7400                                        return Err(Error::UnsupportedRWStorageTexture);
7401                                    }
7402                                }
7403                                _ => {}
7404                            }
7405                        }
7406                        _ => {}
7407                    }
7408                }
7409
7410                // Check min MSL version for binding arrays
7411                match var.space {
7412                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7413                        crate::TypeInner::BindingArray { base, .. } => {
7414                            match module.types[base].inner {
7415                                crate::TypeInner::Sampler { .. } => {
7416                                    if options.lang_version < (2, 0) {
7417                                        return Err(Error::UnsupportedArrayOf(
7418                                            "samplers".to_string(),
7419                                        ));
7420                                    }
7421                                }
7422                                crate::TypeInner::Image { class, .. } => match class {
7423                                    crate::ImageClass::Sampled { .. }
7424                                    | crate::ImageClass::Depth { .. }
7425                                    | crate::ImageClass::Storage {
7426                                        access: crate::StorageAccess::LOAD,
7427                                        ..
7428                                    } => {
7429                                        // Array of textures since:
7430                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7431                                        // - macOS: Metal 2
7432
7433                                        if options.lang_version < (2, 0) {
7434                                            return Err(Error::UnsupportedArrayOf(
7435                                                "textures".to_string(),
7436                                            ));
7437                                        }
7438                                    }
7439                                    crate::ImageClass::Storage {
7440                                        access: crate::StorageAccess::STORE,
7441                                        ..
7442                                    } => {
7443                                        // Array of write-only textures since:
7444                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7445                                        // - macOS: Metal 2
7446
7447                                        if options.lang_version < (2, 0) {
7448                                            return Err(Error::UnsupportedArrayOf(
7449                                                "write-only textures".to_string(),
7450                                            ));
7451                                        }
7452                                    }
7453                                    crate::ImageClass::Storage { .. } => {
7454                                        if options.lang_version < (3, 0) {
7455                                            return Err(Error::UnsupportedArrayOf(
7456                                                "read-write textures".to_string(),
7457                                            ));
7458                                        }
7459                                    }
7460                                    crate::ImageClass::External => {
7461                                        return Err(Error::UnsupportedArrayOf(
7462                                            "external textures".to_string(),
7463                                        ));
7464                                    }
7465                                },
7466                                _ => {
7467                                    return Err(Error::UnsupportedArrayOfType(base));
7468                                }
7469                            }
7470                        }
7471                        _ => {}
7472                    },
7473                    _ => {}
7474                }
7475
7476                // The resolves have already been checked in the `!fake_missing_bindings` case.
7477                let resolved = match var.space {
7478                    crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7479                    crate::AddressSpace::WorkGroup => None,
7480                    crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7481                    _ => options
7482                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7483                        .ok(),
7484                };
7485                if let Some(ref resolved) = resolved {
7486                    // Inline samplers are defined in the EP body
7487                    if resolved.as_inline_sampler(options).is_some() {
7488                        continue;
7489                    }
7490                }
7491
7492                match module.types[var.ty].inner {
7493                    crate::TypeInner::Image {
7494                        class: crate::ImageClass::External,
7495                        ..
7496                    } => {
7497                        // External texture global variables get lowered to 3 textures
7498                        // and a constant buffer. We must emit a separate argument for
7499                        // each of these.
7500                        let target = match resolved {
7501                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7502                                target.external_texture
7503                            }
7504                            _ => None,
7505                        };
7506
7507                        for i in 0..3 {
7508                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7509                                handle,
7510                                ExternalTextureNameKey::Plane(i),
7511                            )];
7512                            let ty_name = format!(
7513                                "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7514                            );
7515                            let name = plane_name.clone();
7516                            let binding = if let Some(ref target) = target {
7517                                format!(" [[texture({})]]", target.planes[i])
7518                            } else {
7519                                String::new()
7520                            };
7521                            args.push(EntryPointArgument {
7522                                ty_name,
7523                                name,
7524                                binding,
7525                                init: None,
7526                            });
7527                        }
7528                        let params_ty_name = &self.names
7529                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7530                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7531                            handle,
7532                            ExternalTextureNameKey::Params,
7533                        )];
7534                        let binding = if let Some(ref target) = target {
7535                            format!(" [[buffer({})]]", target.params)
7536                        } else {
7537                            String::new()
7538                        };
7539
7540                        args.push(EntryPointArgument {
7541                            ty_name: format!("constant {params_ty_name}&"),
7542                            name: params_name.clone(),
7543                            binding,
7544                            init: None,
7545                        });
7546                    }
7547                    _ => {
7548                        if var.space == crate::AddressSpace::WorkGroup
7549                            && ep.stage == crate::ShaderStage::Mesh
7550                        {
7551                            continue;
7552                        }
7553                        let tyvar = TypedGlobalVariable {
7554                            module,
7555                            names: &self.names,
7556                            handle,
7557                            usage,
7558                            reference: true,
7559                        };
7560                        let parts = tyvar.to_parts()?;
7561                        let mut binding = String::new();
7562                        if let Some(resolved) = resolved {
7563                            resolved.try_fmt(&mut binding)?;
7564                        }
7565                        args.push(EntryPointArgument {
7566                            ty_name: parts.ty_name,
7567                            name: parts.var_name,
7568                            binding,
7569                            init: var.init,
7570                        });
7571                    }
7572                }
7573            }
7574
7575            if do_vertex_pulling {
7576                if needs_vertex_id && v_existing_id.is_none() {
7577                    // Write the [[vertex_id]] argument.
7578                    args.push(EntryPointArgument {
7579                        ty_name: "uint".to_string(),
7580                        name: v_id.clone(),
7581                        binding: " [[vertex_id]]".to_string(),
7582                        init: None,
7583                    });
7584                }
7585
7586                if needs_instance_id && i_existing_id.is_none() {
7587                    args.push(EntryPointArgument {
7588                        ty_name: "uint".to_string(),
7589                        name: i_id.clone(),
7590                        binding: " [[instance_id]]".to_string(),
7591                        init: None,
7592                    });
7593                }
7594
7595                // Iterate vbm_resolved, output one argument for every vertex buffer,
7596                // using the names we generated earlier.
7597                for vbm in &vbm_resolved {
7598                    let id = &vbm.id;
7599                    let ty_name = &vbm.ty_name;
7600                    let param_name = &vbm.param_name;
7601                    args.push(EntryPointArgument {
7602                        ty_name: format!("const device {ty_name}*"),
7603                        name: param_name.clone(),
7604                        binding: format!(" [[buffer({id})]]"),
7605                        init: None,
7606                    });
7607                }
7608            }
7609
7610            // If this entry uses any variable-length arrays, their sizes are
7611            // passed as a final struct-typed argument.
7612            if needs_buffer_sizes {
7613                // this is checked earlier
7614                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7615                let mut binding = String::new();
7616                resolved.try_fmt(&mut binding)?;
7617                args.push(EntryPointArgument {
7618                    ty_name: "constant _mslBufferSizes&".to_string(),
7619                    name: "_buffer_sizes".to_string(),
7620                    binding,
7621                    init: None,
7622                });
7623            }
7624
7625            let mut is_first_arg = true;
7626            for arg in &args {
7627                if is_first_arg {
7628                    write!(self.out, "  ")?;
7629                } else {
7630                    write!(self.out, ", ")?;
7631                }
7632                is_first_arg = false;
7633                write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7634                if !is_wrapped {
7635                    write!(self.out, "{}", arg.binding)?;
7636                    if let Some(init) = arg.init {
7637                        write!(self.out, " = ")?;
7638                        self.put_const_expression(
7639                            init,
7640                            module,
7641                            mod_info,
7642                            &module.global_expressions,
7643                        )?;
7644                    }
7645                }
7646                writeln!(self.out)?;
7647            }
7648            if ep.stage == crate::ShaderStage::Mesh {
7649                for (handle, var) in module.global_variables.iter() {
7650                    if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7651                        continue;
7652                    }
7653                    if is_first_arg {
7654                        write!(self.out, "  ")?;
7655                    } else {
7656                        write!(self.out, ", ")?;
7657                    }
7658                    let ty_context = TypeContext {
7659                        handle: module.global_variables[handle].ty,
7660                        gctx: module.to_ctx(),
7661                        names: &self.names,
7662                        access: crate::StorageAccess::empty(),
7663                        first_time: false,
7664                    };
7665                    writeln!(
7666                        self.out,
7667                        "threadgroup {ty_context}& {}",
7668                        self.names[&NameKey::GlobalVariable(handle)]
7669                    )?;
7670                }
7671            }
7672
7673            // end of the entry point argument list
7674            writeln!(self.out, ") {{")?;
7675
7676            // Starting the function body.
7677            if do_vertex_pulling {
7678                // Provide zero values for all the attributes, which we will overwrite with
7679                // real data from the vertex attribute buffers, if the indices are in-bounds.
7680                for vbm in &vbm_resolved {
7681                    for attribute in vbm.attributes {
7682                        let location = attribute.shader_location;
7683                        let am_option = am_resolved.get(&location);
7684                        if am_option.is_none() {
7685                            // This bound attribute isn't used in this entry point, so
7686                            // don't bother zero-initializing it.
7687                            continue;
7688                        }
7689                        let am = am_option.unwrap();
7690                        let attribute_ty_name = &am.ty_name;
7691                        let attribute_name = &am.name;
7692
7693                        writeln!(
7694                            self.out,
7695                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7696                            back::Level(1)
7697                        )?;
7698                    }
7699
7700                    // Output a bounds check block that will set real values for the
7701                    // attributes, if the bounds are satisfied.
7702                    write!(self.out, "{}if (", back::Level(1))?;
7703
7704                    let idx = &vbm.id;
7705                    let stride = &vbm.stride;
7706                    let index_name = match vbm.step_mode {
7707                        back::msl::VertexBufferStepMode::Constant => "0",
7708                        back::msl::VertexBufferStepMode::ByVertex => {
7709                            if let Some(ref name) = v_existing_id {
7710                                name
7711                            } else {
7712                                &v_id
7713                            }
7714                        }
7715                        back::msl::VertexBufferStepMode::ByInstance => {
7716                            if let Some(ref name) = i_existing_id {
7717                                name
7718                            } else {
7719                                &i_id
7720                            }
7721                        }
7722                    };
7723                    write!(
7724                        self.out,
7725                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7726                    )?;
7727
7728                    writeln!(self.out, ") {{")?;
7729
7730                    // Pull the bytes out of the vertex buffer.
7731                    let ty_name = &vbm.ty_name;
7732                    let elem_name = &vbm.elem_name;
7733                    let param_name = &vbm.param_name;
7734
7735                    writeln!(
7736                        self.out,
7737                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7738                        back::Level(2),
7739                    )?;
7740
7741                    // Now set real values for each of the attributes, by unpacking the data
7742                    // from the buffer elements.
7743                    for attribute in vbm.attributes {
7744                        let location = attribute.shader_location;
7745                        let Some(am) = am_resolved.get(&location) else {
7746                            // This bound attribute isn't used in this entry point, so
7747                            // don't bother extracting the data. Too bad we emitted the
7748                            // unpacking function earlier -- it might not get used.
7749                            continue;
7750                        };
7751                        let attribute_name = &am.name;
7752                        let attribute_ty_name = &am.ty_name;
7753
7754                        let offset = attribute.offset;
7755                        let func = unpacking_functions
7756                            .get(&attribute.format)
7757                            .expect("Should have generated this unpacking function earlier.");
7758                        let func_name = &func.name;
7759
7760                        // Check dimensionality of the attribute compared to the unpacking
7761                        // function. If attribute dimension > unpack dimension, we have to
7762                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7763                        // scalar type. Otherwise, if attribute dimension is < unpack
7764                        // dimension, then we need to explicitly truncate the result.
7765                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7766
7767                        // We need an extra type conversion if the shader type does not
7768                        // match the type returned from the unpacking function.
7769                        let needs_conversion = am.scalar != func.scalar;
7770
7771                        if needs_padding_or_truncation != Ordering::Equal {
7772                            // Emit a comment flagging that a conversion is happening,
7773                            // since the actual logic can be at the end of a long line.
7774                            writeln!(
7775                                self.out,
7776                                "{}// {attribute_ty_name} <- {:?}",
7777                                back::Level(2),
7778                                attribute.format
7779                            )?;
7780                        }
7781
7782                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7783
7784                        if needs_padding_or_truncation == Ordering::Greater {
7785                            // Needs padding: emit constructor call for wider type
7786                            write!(self.out, "{attribute_ty_name}(")?;
7787                        }
7788
7789                        // Emit call to unpacking function
7790                        if needs_conversion {
7791                            put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7792                            write!(self.out, "(")?;
7793                        }
7794                        write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7795                        for i in (offset + 1)..(offset + func.byte_count) {
7796                            write!(self.out, ", {elem_name}.data[{i}]")?;
7797                        }
7798                        write!(self.out, ")")?;
7799                        if needs_conversion {
7800                            write!(self.out, ")")?;
7801                        }
7802
7803                        match needs_padding_or_truncation {
7804                            Ordering::Greater => {
7805                                // Padding
7806                                let ty_is_int = scalar_is_int(am.scalar);
7807                                let zero_value = if ty_is_int { "0" } else { "0.0" };
7808                                let one_value = if ty_is_int { "1" } else { "1.0" };
7809                                for i in func.dimension.map_or(1, u8::from)
7810                                    ..am.dimension.map_or(1, u8::from)
7811                                {
7812                                    write!(
7813                                        self.out,
7814                                        ", {}",
7815                                        if i == 3 { one_value } else { zero_value }
7816                                    )?;
7817                                }
7818                            }
7819                            Ordering::Less => {
7820                                // Truncate to the first `am.dimension` components
7821                                write!(
7822                                    self.out,
7823                                    ".{}",
7824                                    &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7825                                )?;
7826                            }
7827                            Ordering::Equal => {}
7828                        }
7829
7830                        if needs_padding_or_truncation == Ordering::Greater {
7831                            write!(self.out, ")")?;
7832                        }
7833
7834                        writeln!(self.out, ";")?;
7835                    }
7836
7837                    // End the bounds check / attribute setting block.
7838                    writeln!(self.out, "{}}}", back::Level(1))?;
7839                }
7840            }
7841
7842            // Metal doesn't support private mutable variables outside of functions,
7843            // so we put them here, just like the locals.
7844            for (handle, var) in module.global_variables.iter() {
7845                let usage = fun_info[handle];
7846                if usage.is_empty() {
7847                    continue;
7848                }
7849                if var.space == crate::AddressSpace::Private {
7850                    let tyvar = TypedGlobalVariable {
7851                        module,
7852                        names: &self.names,
7853                        handle,
7854                        usage,
7855
7856                        reference: false,
7857                    };
7858                    write!(self.out, "{}", back::INDENT)?;
7859                    tyvar.try_fmt(&mut self.out)?;
7860                    match var.init {
7861                        Some(value) => {
7862                            write!(self.out, " = ")?;
7863                            self.put_const_expression(
7864                                value,
7865                                module,
7866                                mod_info,
7867                                &module.global_expressions,
7868                            )?;
7869                            writeln!(self.out, ";")?;
7870                        }
7871                        None => {
7872                            writeln!(self.out, " = {{}};")?;
7873                        }
7874                    };
7875                } else if let Some(ref binding) = var.binding {
7876                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7877                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7878                        // write an inline sampler
7879                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7880                        writeln!(
7881                            self.out,
7882                            "{}constexpr {}::sampler {}(",
7883                            back::INDENT,
7884                            NAMESPACE,
7885                            name
7886                        )?;
7887                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7888                        writeln!(self.out, "{});", back::INDENT)?;
7889                    } else if let crate::TypeInner::Image {
7890                        class: crate::ImageClass::External,
7891                        ..
7892                    } = module.types[var.ty].inner
7893                    {
7894                        // Wrap the individual arguments for each external texture global
7895                        // in a struct which can be easily passed around.
7896                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7897                        let l1 = back::Level(1);
7898                        let l2 = l1.next();
7899                        writeln!(
7900                            self.out,
7901                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7902                        )?;
7903                        for i in 0..3 {
7904                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7905                                handle,
7906                                ExternalTextureNameKey::Plane(i),
7907                            )];
7908                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7909                        }
7910                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7911                            handle,
7912                            ExternalTextureNameKey::Params,
7913                        )];
7914                        writeln!(self.out, "{l2}.params = {params_name},")?;
7915                        writeln!(self.out, "{l1}}};")?;
7916                    }
7917                }
7918            }
7919
7920            if need_workgroup_variables_initialization {
7921                self.write_workgroup_variables_initialization(
7922                    module,
7923                    mod_info,
7924                    fun_info,
7925                    local_invocation_index,
7926                    ep.stage,
7927                )?;
7928            }
7929
7930            // Now take the arguments that we gathered into structs, and the
7931            // structs that we flattened into arguments, and emit local
7932            // variables with initializers that put everything back the way the
7933            // body code expects.
7934            //
7935            // If we had to generate fresh names for struct members passed as
7936            // arguments, be sure to use those names when rebuilding the struct.
7937            //
7938            // "Each day, I change some zeros to ones, and some ones to zeros.
7939            // The rest, I leave alone."
7940            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7941                let arg_name =
7942                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7943                match module.types[arg.ty].inner {
7944                    crate::TypeInner::Struct { ref members, .. } => {
7945                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7946                        write!(
7947                            self.out,
7948                            "{}const {} {} = {{ ",
7949                            back::INDENT,
7950                            struct_name,
7951                            arg_name
7952                        )?;
7953                        for (member_index, member) in members.iter().enumerate() {
7954                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7955                            let name = &flattened_member_names[&key];
7956                            if member_index != 0 {
7957                                write!(self.out, ", ")?;
7958                            }
7959                            // insert padding initialization, if needed
7960                            if self
7961                                .struct_member_pads
7962                                .contains(&(arg.ty, member_index as u32))
7963                            {
7964                                write!(self.out, "{{}}, ")?;
7965                            }
7966                            match member.binding {
7967                                Some(crate::Binding::Location {
7968                                    interpolation: Some(crate::Interpolation::PerVertex),
7969                                    ..
7970                                }) => {
7971                                    writeln!(
7972                                        self.out,
7973                                        "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
7974                                        back::INDENT,
7975                                        varyings_member_name,
7976                                        arg_name,
7977                                    )?;
7978                                    continue;
7979                                }
7980                                Some(crate::Binding::Location { .. }) => {
7981                                    if has_varyings {
7982                                        write!(self.out, "{varyings_member_name}.")?;
7983                                    }
7984                                }
7985                                _ => (),
7986                            }
7987                            write!(self.out, "{name}")?;
7988                        }
7989                        writeln!(self.out, " }};")?;
7990                    }
7991                    _ => match arg.binding {
7992                        Some(crate::Binding::Location {
7993                            interpolation: Some(crate::Interpolation::PerVertex),
7994                            ..
7995                        }) => {
7996                            let ty_name = TypeContext {
7997                                handle: arg.ty,
7998                                gctx: module.to_ctx(),
7999                                names: &self.names,
8000                                access: crate::StorageAccess::empty(),
8001                                first_time: false,
8002                            };
8003                            writeln!(
8004                                self.out,
8005                                "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8006                                back::INDENT,
8007                                varyings_member_name,
8008                                arg_name,
8009                            )?;
8010                        }
8011                        Some(crate::Binding::Location { .. })
8012                        | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8013                            if has_varyings {
8014                                writeln!(
8015                                    self.out,
8016                                    "{}const auto {} = {}.{};",
8017                                    back::INDENT,
8018                                    arg_name,
8019                                    varyings_member_name,
8020                                    arg_name
8021                                )?;
8022                            }
8023                        }
8024                        _ => {}
8025                    },
8026                }
8027            }
8028
8029            let guarded_indices =
8030                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8031
8032            let context = StatementContext {
8033                expression: ExpressionContext {
8034                    function: fun,
8035                    origin: FunctionOrigin::EntryPoint(ep_index as _),
8036                    info: fun_info,
8037                    lang_version: options.lang_version,
8038                    policies: options.bounds_check_policies,
8039                    guarded_indices,
8040                    module,
8041                    mod_info,
8042                    pipeline_options,
8043                    force_loop_bounding: options.force_loop_bounding,
8044                },
8045                result_struct: if ep.stage == crate::ShaderStage::Task {
8046                    None
8047                } else {
8048                    Some(&stage_out_name)
8049                },
8050            };
8051
8052            // Finally, declare all the local variables that we need
8053            //TODO: we can postpone this till the relevant expressions are emitted
8054            self.put_locals(&context.expression)?;
8055            self.update_expressions_to_bake(fun, fun_info, &context.expression);
8056            self.put_block(back::Level(1), &fun.body, &context)?;
8057            writeln!(self.out, "}}")?;
8058            if ep_index + 1 != module.entry_points.len() {
8059                writeln!(self.out)?;
8060            }
8061            self.named_expressions.clear();
8062
8063            if is_wrapped {
8064                self.write_wrapper_function(NestedFunctionInfo {
8065                    options,
8066                    ep,
8067                    module,
8068                    mod_info,
8069                    fun_info,
8070                    args,
8071                    local_invocation_index,
8072                    nested_name: &nested_fun_name,
8073                    outer_name: &fun_name,
8074                    out_mesh_info,
8075                })?;
8076            }
8077        }
8078
8079        Ok(info)
8080    }
8081
8082    pub(super) fn write_barrier(
8083        &mut self,
8084        flags: crate::Barrier,
8085        level: back::Level,
8086    ) -> BackendResult {
8087        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
8088        // so we try to avoid it here.
8089        if flags.is_empty() {
8090            writeln!(
8091                self.out,
8092                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8093            )?;
8094        }
8095        if flags.contains(crate::Barrier::STORAGE) {
8096            writeln!(
8097                self.out,
8098                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8099            )?;
8100        }
8101        if flags.contains(crate::Barrier::WORK_GROUP) {
8102            writeln!(
8103                self.out,
8104                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8105            )?;
8106            if self.needs_object_memory_barriers {
8107                writeln!(
8108                    self.out,
8109                    "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8110                )?;
8111            }
8112        }
8113        if flags.contains(crate::Barrier::SUB_GROUP) {
8114            writeln!(
8115                self.out,
8116                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8117            )?;
8118        }
8119        if flags.contains(crate::Barrier::TEXTURE) {
8120            writeln!(
8121                self.out,
8122                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8123            )?;
8124        }
8125        Ok(())
8126    }
8127}
8128
8129/// Initializing workgroup variables is more tricky for Metal because we have to deal
8130/// with atomics at the type-level (which don't have a copy constructor).
8131mod workgroup_mem_init {
8132    use crate::EntryPoint;
8133
8134    use super::*;
8135
8136    enum Access {
8137        GlobalVariable(Handle<crate::GlobalVariable>),
8138        StructMember(Handle<crate::Type>, u32),
8139        Array(usize),
8140    }
8141
8142    impl Access {
8143        fn write<W: Write>(
8144            &self,
8145            writer: &mut W,
8146            names: &FastHashMap<NameKey, String>,
8147        ) -> Result<(), core::fmt::Error> {
8148            match *self {
8149                Access::GlobalVariable(handle) => {
8150                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8151                }
8152                Access::StructMember(handle, index) => {
8153                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8154                }
8155                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8156            }
8157        }
8158    }
8159
8160    struct AccessStack {
8161        stack: Vec<Access>,
8162        array_depth: usize,
8163    }
8164
8165    impl AccessStack {
8166        const fn new() -> Self {
8167            Self {
8168                stack: Vec::new(),
8169                array_depth: 0,
8170            }
8171        }
8172
8173        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8174            let array_depth = self.array_depth;
8175            self.stack.push(Access::Array(array_depth));
8176            self.array_depth += 1;
8177            let res = cb(self, array_depth);
8178            self.stack.pop();
8179            self.array_depth -= 1;
8180            res
8181        }
8182
8183        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8184            self.stack.push(new);
8185            let res = cb(self);
8186            self.stack.pop();
8187            res
8188        }
8189
8190        fn write<W: Write>(
8191            &self,
8192            writer: &mut W,
8193            names: &FastHashMap<NameKey, String>,
8194        ) -> Result<(), core::fmt::Error> {
8195            for next in self.stack.iter() {
8196                next.write(writer, names)?;
8197            }
8198            Ok(())
8199        }
8200    }
8201
8202    impl<W: Write> Writer<W> {
8203        pub(super) fn need_workgroup_variables_initialization(
8204            &mut self,
8205            options: &Options,
8206            ep: &EntryPoint,
8207            module: &crate::Module,
8208            fun_info: &valid::FunctionInfo,
8209        ) -> bool {
8210            let is_task = ep.stage == crate::ShaderStage::Task;
8211            options.zero_initialize_workgroup_memory
8212                && ep.stage.compute_like()
8213                && module.global_variables.iter().any(|(handle, var)| {
8214                    let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8215                        || (var.space == crate::AddressSpace::TaskPayload && is_task);
8216                    !fun_info[handle].is_empty() && is_right_address_space
8217                })
8218        }
8219
8220        pub fn write_workgroup_variables_initialization(
8221            &mut self,
8222            module: &crate::Module,
8223            module_info: &valid::ModuleInfo,
8224            fun_info: &valid::FunctionInfo,
8225            local_invocation_index: Option<&NameKey>,
8226            stage: crate::ShaderStage,
8227        ) -> BackendResult {
8228            let level = back::Level(1);
8229
8230            writeln!(
8231                self.out,
8232                "{}if ({} == 0u) {{",
8233                level,
8234                local_invocation_index
8235                    .map(|name_key| self.names[name_key].as_str())
8236                    .unwrap_or("__local_invocation_index"),
8237            )?;
8238
8239            let mut access_stack = AccessStack::new();
8240
8241            let is_task = stage == crate::ShaderStage::Task;
8242            let vars = module.global_variables.iter().filter(|&(handle, var)| {
8243                let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8244                    || (var.space == crate::AddressSpace::TaskPayload && is_task);
8245                !fun_info[handle].is_empty() && is_right_address_space
8246            });
8247
8248            for (handle, var) in vars {
8249                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8250                    self.write_workgroup_variable_initialization(
8251                        module,
8252                        module_info,
8253                        var.ty,
8254                        access_stack,
8255                        level.next(),
8256                    )
8257                })?;
8258            }
8259
8260            writeln!(self.out, "{level}}}")?;
8261            self.write_barrier(crate::Barrier::WORK_GROUP, level)
8262        }
8263
8264        fn write_workgroup_variable_initialization(
8265            &mut self,
8266            module: &crate::Module,
8267            module_info: &valid::ModuleInfo,
8268            ty: Handle<crate::Type>,
8269            access_stack: &mut AccessStack,
8270            level: back::Level,
8271        ) -> BackendResult {
8272            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8273                write!(self.out, "{level}")?;
8274                access_stack.write(&mut self.out, &self.names)?;
8275                writeln!(self.out, " = {{}};")?;
8276            } else {
8277                match module.types[ty].inner {
8278                    crate::TypeInner::Atomic { .. } => {
8279                        write!(
8280                            self.out,
8281                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8282                        )?;
8283                        access_stack.write(&mut self.out, &self.names)?;
8284                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8285                    }
8286                    crate::TypeInner::Array { base, size, .. } => {
8287                        let count = match size.resolve(module.to_ctx())? {
8288                            proc::IndexableLength::Known(count) => count,
8289                            proc::IndexableLength::Dynamic => unreachable!(),
8290                        };
8291
8292                        access_stack.enter_array(|access_stack, array_depth| {
8293                            writeln!(
8294                                self.out,
8295                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8296                            )?;
8297                            self.write_workgroup_variable_initialization(
8298                                module,
8299                                module_info,
8300                                base,
8301                                access_stack,
8302                                level.next(),
8303                            )?;
8304                            writeln!(self.out, "{level}}}")?;
8305                            BackendResult::Ok(())
8306                        })?;
8307                    }
8308                    crate::TypeInner::Struct { ref members, .. } => {
8309                        for (index, member) in members.iter().enumerate() {
8310                            access_stack.enter(
8311                                Access::StructMember(ty, index as u32),
8312                                |access_stack| {
8313                                    self.write_workgroup_variable_initialization(
8314                                        module,
8315                                        module_info,
8316                                        member.ty,
8317                                        access_stack,
8318                                        level,
8319                                    )
8320                                },
8321                            )?;
8322                        }
8323                    }
8324                    _ => unreachable!(),
8325                }
8326            }
8327
8328            Ok(())
8329        }
8330    }
8331}
8332
8333impl crate::AtomicFunction {
8334    const fn to_msl(self) -> &'static str {
8335        match self {
8336            Self::Add => "fetch_add",
8337            Self::Subtract => "fetch_sub",
8338            Self::And => "fetch_and",
8339            Self::InclusiveOr => "fetch_or",
8340            Self::ExclusiveOr => "fetch_xor",
8341            Self::Min => "fetch_min",
8342            Self::Max => "fetch_max",
8343            Self::Exchange { compare: None } => "exchange",
8344            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8345        }
8346    }
8347
8348    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8349        Ok(match self {
8350            Self::Min => "min",
8351            Self::Max => "max",
8352            _ => Err(Error::FeatureNotImplemented(
8353                "64-bit atomic operation other than min/max".to_string(),
8354            ))?,
8355        })
8356    }
8357}