naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18    arena::{Handle, HandleSet},
19    back::{self, get_entry_points, Baked},
20    common,
21    proc::{
22        self, concrete_int_scalars,
23        index::{self, BoundsCheck},
24        ExternalTextureNameKey, NameKey, TypeResolution,
25    },
26    valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Shorthand result used internally by the backend.
///
/// Success carries no value; failure carries this backend's [`Error`].
type BackendResult = Result<(), Error>;
34
// Namespace prefix written in front of the Metal type names this file emits
// (e.g. vector, matrix, atomic, texture and sampler types).
const NAMESPACE: &str = "metal";
// The name of the array member of the Metal struct types we generate to
// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
// for details.
const WRAPPED_ARRAY_FIELD: &str = "inner";
// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// Some more general handling of pointers still needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

// Namespace prefix for the ray tracing types this file emits
// (e.g. `instance_acceleration_structure`).
const RT_NAMESPACE: &str = "metal::raytracing";
// Type name written for Naga `RayQuery` types.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// NOTE(review): the `RAY_QUERY_FIELD_*` constants are presumably member names
// of the `_RayQuery` struct; their uses are outside this part of the file.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_MODERN_SUPPORT: bool = false; //TODO
const RAY_QUERY_FIELD_READY: &str = "ready";
// NOTE(review): presumably the name of a generated helper function that maps
// intersection types; confirm against its use site.
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
52
// Identifiers shared with the rest of the crate.
//
// NOTE(review): judging by their names, the `*_FUNCTION` constants are the
// names of helper functions written into the generated MSL; their definitions
// and call sites are outside this part of the file, so confirm there.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
// Unlike its neighbours this is only a prefix; the remainder of the name is
// presumably appended at the use site.
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you have noticed that this should be exactly the same to the compiler, and you're correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
84
85/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
86///
87/// The `sizes` slice determines whether this function writes a
88/// scalar, vector, or matrix type:
89///
90/// - An empty slice produces a scalar type.
91/// - A one-element slice produces a vector type.
92/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
93fn put_numeric_type(
94    out: &mut impl Write,
95    scalar: crate::Scalar,
96    sizes: &[crate::VectorSize],
97) -> Result<(), FmtError> {
98    match (scalar, sizes) {
99        (scalar, &[]) => {
100            write!(out, "{}", scalar.to_msl_name())
101        }
102        (scalar, &[rows]) => {
103            write!(
104                out,
105                "{}::{}{}",
106                NAMESPACE,
107                scalar.to_msl_name(),
108                common::vector_size_str(rows)
109            )
110        }
111        (scalar, &[rows, columns]) => {
112            write!(
113                out,
114                "{}::{}{}x{}",
115                NAMESPACE,
116                scalar.to_msl_name(),
117                common::vector_size_str(columns),
118                common::vector_size_str(rows)
119            )
120        }
121        (_, _) => Ok(()), // not meaningful
122    }
123}
124
125const fn scalar_is_int(scalar: crate::Scalar) -> bool {
126    use crate::ScalarKind::*;
127    match scalar.kind {
128        Sint | Uint | AbstractInt | Bool => true,
129        Float | AbstractFloat => false,
130    }
131}
132
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
///
/// Written by the `Display` implementation of `ClampedLod`.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
///
/// Written first by the `Display` implementation of `Reinterpreted`.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
138
139/// Wrapper for identifier names for clamped level-of-detail values
140///
141/// Values of this type implement [`core::fmt::Display`], formatting as
142/// the name of the variable used to hold the cached clamped
143/// level-of-detail value for an `ImageLoad` expression.
144struct ClampedLod(Handle<crate::Expression>);
145
146impl Display for ClampedLod {
147    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
148        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
149    }
150}
151
152/// Wrapper for generating `struct _mslBufferSizes` member names for
153/// runtime-sized array lengths.
154///
155/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
156/// as an argument to the entry point. This argument's type in the MSL is
157/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
158/// each global variable containing a runtime-sized array.
159///
160/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
161/// runtime-sized array, then the value `ArraySize(global)` implements
162/// [`core::fmt::Display`], formatting as the name of the struct member carrying
163/// the number of elements in that runtime-sized array.
164///
165/// [`GlobalVariable`]: crate::GlobalVariable
166struct ArraySizeMember(Handle<crate::GlobalVariable>);
167
168impl Display for ArraySizeMember {
169    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
170        self.0.write_prefixed(f, "size")
171    }
172}
173
174/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
175///
176/// Implements [`core::fmt::Display`], formatting as a name derived from
177/// `target_type` and the variable name of `orig`.
178#[derive(Clone, Copy)]
179struct Reinterpreted<'a> {
180    target_type: &'a str,
181    orig: Handle<crate::Expression>,
182}
183
184impl<'a> Reinterpreted<'a> {
185    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
186        Self { target_type, orig }
187    }
188}
189
190impl Display for Reinterpreted<'_> {
191    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
192        f.write_str(REINTERPRET_PREFIX)?;
193        f.write_str(self.target_type)?;
194        self.orig.write_prefixed(f, "_e")
195    }
196}
197
/// Display adapter that writes the MSL spelling of a Naga type.
///
/// The formatting itself lives in this type's [`Display`] implementation.
struct TypeContext<'a> {
    /// The type to write.
    handle: Handle<crate::Type>,
    /// Module context whose type arena `handle` is looked up in.
    gctx: proc::GlobalCtx<'a>,
    /// Generated names; consulted for the alias of a type that needs one.
    names: &'a FastHashMap<NameKey, String>,
    /// Storage access used to pick the `access::` qualifier of storage images.
    access: crate::StorageAccess,
    /// When `false`, a type that needs an alias is written as that alias
    /// instead of being spelled out.
    first_time: bool,
}
205
206impl TypeContext<'_> {
207    fn scalar(&self) -> Option<crate::Scalar> {
208        let ty = &self.gctx.types[self.handle];
209        ty.inner.scalar()
210    }
211
212    fn vertex_input_dimension(&self) -> u32 {
213        let ty = &self.gctx.types[self.handle];
214        match ty.inner {
215            crate::TypeInner::Scalar(_) => 1,
216            crate::TypeInner::Vector { size, .. } => size as u32,
217            _ => unreachable!(),
218        }
219    }
220}
221
222impl Display for TypeContext<'_> {
223    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
224        let ty = &self.gctx.types[self.handle];
225        if ty.needs_alias() && !self.first_time {
226            let name = &self.names[&NameKey::Type(self.handle)];
227            return write!(out, "{name}");
228        }
229
230        match ty.inner {
231            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
232            crate::TypeInner::Atomic(scalar) => {
233                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
234            }
235            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
236            crate::TypeInner::Matrix {
237                columns,
238                rows,
239                scalar,
240            } => put_numeric_type(out, scalar, &[rows, columns]),
241            // Requires Metal-2.3
242            crate::TypeInner::CooperativeMatrix {
243                columns,
244                rows,
245                scalar,
246                role: _,
247            } => {
248                write!(
249                    out,
250                    "{NAMESPACE}::simdgroup_{}{}x{}",
251                    scalar.to_msl_name(),
252                    columns as u32,
253                    rows as u32,
254                )
255            }
256            crate::TypeInner::Pointer { base, space } => {
257                let sub = Self {
258                    handle: base,
259                    first_time: false,
260                    ..*self
261                };
262                let space_name = match space.to_msl_name() {
263                    Some(name) => name,
264                    None => return Ok(()),
265                };
266                write!(out, "{space_name} {sub}&")
267            }
268            crate::TypeInner::ValuePointer {
269                size,
270                scalar,
271                space,
272            } => {
273                match space.to_msl_name() {
274                    Some(name) => write!(out, "{name} ")?,
275                    None => return Ok(()),
276                };
277                match size {
278                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
279                    None => put_numeric_type(out, scalar, &[])?,
280                };
281
282                write!(out, "&")
283            }
284            crate::TypeInner::Array { base, .. } => {
285                let sub = Self {
286                    handle: base,
287                    first_time: false,
288                    ..*self
289                };
290                // Array lengths go at the end of the type definition,
291                // so just print the element type here.
292                write!(out, "{sub}")
293            }
294            crate::TypeInner::Struct { .. } => unreachable!(),
295            crate::TypeInner::Image {
296                dim,
297                arrayed,
298                class,
299            } => {
300                let dim_str = match dim {
301                    crate::ImageDimension::D1 => "1d",
302                    crate::ImageDimension::D2 => "2d",
303                    crate::ImageDimension::D3 => "3d",
304                    crate::ImageDimension::Cube => "cube",
305                };
306                let (texture_str, msaa_str, scalar, access) = match class {
307                    crate::ImageClass::Sampled { kind, multi } => {
308                        let (msaa_str, access) = if multi {
309                            ("_ms", "read")
310                        } else {
311                            ("", "sample")
312                        };
313                        let scalar = crate::Scalar { kind, width: 4 };
314                        ("texture", msaa_str, scalar, access)
315                    }
316                    crate::ImageClass::Depth { multi } => {
317                        let (msaa_str, access) = if multi {
318                            ("_ms", "read")
319                        } else {
320                            ("", "sample")
321                        };
322                        let scalar = crate::Scalar {
323                            kind: crate::ScalarKind::Float,
324                            width: 4,
325                        };
326                        ("depth", msaa_str, scalar, access)
327                    }
328                    crate::ImageClass::Storage { format, .. } => {
329                        let access = if self
330                            .access
331                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
332                        {
333                            "read_write"
334                        } else if self.access.contains(crate::StorageAccess::STORE) {
335                            "write"
336                        } else if self.access.contains(crate::StorageAccess::LOAD) {
337                            "read"
338                        } else {
339                            log::warn!(
340                                "Storage access for {:?} (name '{}'): {:?}",
341                                self.handle,
342                                ty.name.as_deref().unwrap_or_default(),
343                                self.access
344                            );
345                            unreachable!("module is not valid");
346                        };
347                        ("texture", "", format.into(), access)
348                    }
349                    crate::ImageClass::External => {
350                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
351                    }
352                };
353                let base_name = scalar.to_msl_name();
354                let array_str = if arrayed { "_array" } else { "" };
355                write!(
356                    out,
357                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
358                )
359            }
360            crate::TypeInner::Sampler { comparison: _ } => {
361                write!(out, "{NAMESPACE}::sampler")
362            }
363            crate::TypeInner::AccelerationStructure { vertex_return } => {
364                if vertex_return {
365                    unimplemented!("metal does not support vertex ray hit return")
366                }
367                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
368            }
369            crate::TypeInner::RayQuery { vertex_return } => {
370                if vertex_return {
371                    unimplemented!("metal does not support vertex ray hit return")
372                }
373                write!(out, "{RAY_QUERY_TYPE}")
374            }
375            crate::TypeInner::BindingArray { base, .. } => {
376                let base_tyname = Self {
377                    handle: base,
378                    first_time: false,
379                    ..*self
380                };
381
382                write!(
383                    out,
384                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
385                )
386            }
387        }
388    }
389}
390
/// Writes a global variable as `<type> <name>`, optionally as a reference
/// qualified with its address space.
///
/// See `try_fmt` for the exact output shape.
struct TypedGlobalVariable<'a> {
    /// Module the variable and its type are looked up in.
    module: &'a crate::Module,
    /// Generated names; must contain an entry for `handle`.
    names: &'a FastHashMap<NameKey, String>,
    /// The global variable to write.
    handle: Handle<crate::GlobalVariable>,
    /// How the variable is used; a variable that is never written may be
    /// emitted as `const`.
    usage: valid::GlobalUse,
    /// Whether to write the variable as a reference (`&`) with its address
    /// space, when that address space has an MSL name.
    reference: bool,
}
398
399impl TypedGlobalVariable<'_> {
400    fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
401        let var = &self.module.global_variables[self.handle];
402        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
403
404        let storage_access = match var.space {
405            crate::AddressSpace::Storage { access } => access,
406            _ => match self.module.types[var.ty].inner {
407                crate::TypeInner::Image {
408                    class: crate::ImageClass::Storage { access, .. },
409                    ..
410                } => access,
411                crate::TypeInner::BindingArray { base, .. } => {
412                    match self.module.types[base].inner {
413                        crate::TypeInner::Image {
414                            class: crate::ImageClass::Storage { access, .. },
415                            ..
416                        } => access,
417                        _ => crate::StorageAccess::default(),
418                    }
419                }
420                _ => crate::StorageAccess::default(),
421            },
422        };
423        let ty_name = TypeContext {
424            handle: var.ty,
425            gctx: self.module.to_ctx(),
426            names: self.names,
427            access: storage_access,
428            first_time: false,
429        };
430
431        let (space, access, reference) = match var.space.to_msl_name() {
432            Some(space) if self.reference => {
433                let access = if var.space.needs_access_qualifier()
434                    && !self.usage.intersects(valid::GlobalUse::WRITE)
435                {
436                    "const"
437                } else {
438                    ""
439                };
440                (space, access, "&")
441            }
442            _ => ("", "", ""),
443        };
444
445        Ok(write!(
446            out,
447            "{}{}{}{}{}{} {}",
448            space,
449            if space.is_empty() { "" } else { " " },
450            ty_name,
451            if access.is_empty() { "" } else { " " },
452            access,
453            reference,
454            name,
455        )?)
456    }
457}
458
/// Key describing one wrapper (helper) function, distinguished by the
/// operation it wraps and the types it operates on.
///
/// Values are hashable and comparable so they can be kept in the `Writer`'s
/// `wrapped_functions` set.
///
/// NOTE(review): presumably that set records which wrappers have already been
/// emitted so each is written only once; the code doing so is outside this
/// part of the file.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Wrapper for a unary operator on an (optional vector size, scalar) type.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper for a binary operator, keyed by both operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper for a math function, keyed by its argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Wrapper for a cast between two scalar types, optionally vectorized.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Wrapper for loading from an image of the given class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Wrapper for sampling an image of the given class.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Wrapper for querying the size of an image of the given class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Wrapper for a cooperative matrix load.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Wrapper for a cooperative matrix multiply-add.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}
503
/// The MSL writer: owns the output sink `out` together with the naming and
/// bookkeeping state used while writing.
///
/// Create one with `Writer::new` and recover the sink with `Writer::finish`.
pub struct Writer<W> {
    /// Output sink the generated text is written to.
    out: W,
    /// Names assigned to module items, keyed by [`NameKey`].
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    namer: proc::Namer,
    /// Set of [`WrappedFunction`] keys tracked by the writer.
    wrapped_functions: FastHashSet<WrappedFunction>,
    // The two pointer sets below only exist in test builds.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
520
521impl crate::Scalar {
522    pub(super) fn to_msl_name(self) -> &'static str {
523        use crate::ScalarKind as Sk;
524        match self {
525            Self {
526                kind: Sk::Float,
527                width: 4,
528            } => "float",
529            Self {
530                kind: Sk::Float,
531                width: 2,
532            } => "half",
533            Self {
534                kind: Sk::Sint,
535                width: 4,
536            } => "int",
537            Self {
538                kind: Sk::Uint,
539                width: 4,
540            } => "uint",
541            Self {
542                kind: Sk::Sint,
543                width: 8,
544            } => "long",
545            Self {
546                kind: Sk::Uint,
547                width: 8,
548            } => "ulong",
549            Self {
550                kind: Sk::Bool,
551                width: _,
552            } => "bool",
553            Self {
554                kind: Sk::AbstractInt | Sk::AbstractFloat,
555                width: _,
556            } => unreachable!("Found Abstract scalar kind"),
557            _ => unreachable!("Unsupported scalar kind: {:?}", self),
558        }
559    }
560}
561
/// Returns `","` when a separator is needed and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
569
570fn should_pack_struct_member(
571    members: &[crate::StructMember],
572    span: u32,
573    index: usize,
574    module: &crate::Module,
575) -> Option<crate::Scalar> {
576    let member = &members[index];
577
578    let ty_inner = &module.types[member.ty].inner;
579    let last_offset = member.offset + ty_inner.size(module.to_ctx());
580    let next_offset = match members.get(index + 1) {
581        Some(next) => next.offset,
582        None => span,
583    };
584    let is_tight = next_offset == last_offset;
585
586    match *ty_inner {
587        crate::TypeInner::Vector {
588            size: crate::VectorSize::Tri,
589            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
590        } if is_tight => Some(scalar),
591        _ => None,
592    }
593}
594
595fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
596    match arena[ty].inner {
597        crate::TypeInner::Struct { ref members, .. } => {
598            if let Some(member) = members.last() {
599                if let crate::TypeInner::Array {
600                    size: crate::ArraySize::Dynamic,
601                    ..
602                } = arena[member.ty].inner
603                {
604                    return true;
605                }
606            }
607            false
608        }
609        crate::TypeInner::Array {
610            size: crate::ArraySize::Dynamic,
611            ..
612        } => true,
613        _ => false,
614    }
615}
616
617impl crate::AddressSpace {
618    /// Returns true if global variables in this address space are
619    /// passed in function arguments. These arguments need to be
620    /// passed through any functions called from the entry point.
621    const fn needs_pass_through(&self) -> bool {
622        match *self {
623            Self::Uniform
624            | Self::Storage { .. }
625            | Self::Private
626            | Self::WorkGroup
627            | Self::Immediate
628            | Self::Handle
629            | Self::TaskPayload => true,
630            Self::Function => false,
631        }
632    }
633
634    /// Returns true if the address space may need a "const" qualifier.
635    const fn needs_access_qualifier(&self) -> bool {
636        match *self {
637            //Note: we are ignoring the storage access here, and instead
638            // rely on the actual use of a global by functions. This means we
639            // may end up with "const" even if the binding is read-write,
640            // and that should be OK.
641            Self::Storage { .. } => true,
642            Self::TaskPayload => unimplemented!(),
643            // These should always be read-write.
644            Self::Private | Self::WorkGroup => false,
645            // These translate to `constant` address space, no need for qualifiers.
646            Self::Uniform | Self::Immediate => false,
647            // Not applicable.
648            Self::Handle | Self::Function => false,
649        }
650    }
651
652    const fn to_msl_name(self) -> Option<&'static str> {
653        match self {
654            Self::Handle => None,
655            Self::Uniform | Self::Immediate => Some("constant"),
656            Self::Storage { .. } => Some("device"),
657            Self::Private | Self::Function => Some("thread"),
658            Self::WorkGroup => Some("threadgroup"),
659            Self::TaskPayload => Some("object_data"),
660        }
661    }
662}
663
664impl crate::Type {
665    // Returns `true` if we need to emit an alias for this type.
666    const fn needs_alias(&self) -> bool {
667        use crate::TypeInner as Ti;
668
669        match self.inner {
670            // value types are concise enough, we only alias them if they are named
671            Ti::Scalar(_)
672            | Ti::Vector { .. }
673            | Ti::Matrix { .. }
674            | Ti::CooperativeMatrix { .. }
675            | Ti::Atomic(_)
676            | Ti::Pointer { .. }
677            | Ti::ValuePointer { .. } => self.name.is_some(),
678            // composite types are better to be aliased, regardless of the name
679            Ti::Struct { .. } | Ti::Array { .. } => true,
680            // handle types may be different, depending on the global var access, so we always inline them
681            Ti::Image { .. }
682            | Ti::Sampler { .. }
683            | Ti::AccelerationStructure { .. }
684            | Ti::RayQuery { .. }
685            | Ti::BindingArray { .. } => false,
686        }
687    }
688}
689
/// Identifies a function either as a regular function of the module or as
/// one of its entry points.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A function, identified by its handle.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by its index.
    EntryPoint(proc::EntryPointIndex),
}
695
696trait NameKeyExt {
697    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
698        match origin {
699            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
700            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
701        }
702    }
703
704    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
705    /// policy when it needs to produce a pointer-typed result for an OOB access. These
706    /// are unique per accessed type, so the second argument is a type handle. See docs
707    /// for [`crate::back::msl`].
708    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
709        match origin {
710            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
711            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
712        }
713    }
714}
715
716impl NameKeyExt for NameKey {}
717
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use the level-of-detail expression itself.
    Direct(Handle<crate::Expression>),
    /// Use the cached, clamped level of detail.
    ///
    /// NOTE(review): presumably the handle is that of the `ImageLoad`
    /// expression the temporary is named after; confirm at the use sites.
    Restricted(Handle<crate::Expression>),
}
732
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Expression giving the texel coordinate.
    coordinate: Handle<crate::Expression>,
    /// Expression giving the array index, if any.
    array_index: Option<Handle<crate::Expression>>,
    /// Expression giving the sample index, if any.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if any.
    level: Option<LevelOfDetail>,
}
748
/// Everything needed to write the expressions of one function.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; used to resolve expression types.
    info: &'a valid::FunctionInfo,
    /// The module `function` belongs to.
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    // NOTE(review): presumably the targeted MSL version as (major, minor);
    // confirm where it is set.
    lang_version: (u8, u8),
    /// Bounds-check policies to choose from for each access.
    policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    force_loop_bounding: bool,
}
766
767impl<'a> ExpressionContext<'a> {
768    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
769        self.info[handle].ty.inner_with(&self.module.types)
770    }
771
772    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
773    ///
774    /// Only mipmapped images need to specify a level of detail. Since 1D
775    /// textures cannot have mipmaps, MSL requires that the level argument to
776    /// texture1d queries and accesses must be a constexpr 0. It's easiest
777    /// just to omit the level entirely for 1D textures.
778    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
779        let image_ty = self.resolve_type(image);
780        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
781            class.is_mipmapped() && dim != crate::ImageDimension::D1
782        } else {
783            false
784        }
785    }
786
787    fn choose_bounds_check_policy(
788        &self,
789        pointer: Handle<crate::Expression>,
790    ) -> index::BoundsCheckPolicy {
791        self.policies
792            .choose_policy(pointer, &self.module.types, self.info)
793    }
794
795    /// See docs for [`proc::index::access_needs_check`].
796    fn access_needs_check(
797        &self,
798        base: Handle<crate::Expression>,
799        index: index::GuardedIndex,
800    ) -> Option<index::IndexableLength> {
801        index::access_needs_check(
802            base,
803            index,
804            self.module,
805            &self.function.expressions,
806            self.info,
807        )
808    }
809
810    /// See docs for [`proc::index::bounds_check_iter`].
811    fn bounds_check_iter(
812        &self,
813        chain: Handle<crate::Expression>,
814    ) -> impl Iterator<Item = BoundsCheck> + '_ {
815        index::bounds_check_iter(chain, self.module, self.function, self.info)
816    }
817
818    /// See docs for [`proc::index::oob_local_types`].
819    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
820        index::oob_local_types(self.module, self.function, self.info, self.policies)
821    }
822
823    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
824        match self.function.expressions[expr_handle] {
825            crate::Expression::AccessIndex { base, index } => {
826                let ty = match *self.resolve_type(base) {
827                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
828                    ref ty => ty,
829                };
830                match *ty {
831                    crate::TypeInner::Struct {
832                        ref members, span, ..
833                    } => should_pack_struct_member(members, span, index as usize, self.module),
834                    _ => None,
835                }
836            }
837            _ => None,
838        }
839    }
840}
841
/// Context handed to the statement-writing methods of the `Writer`.
///
/// Bundles the expression-level context with extra state that only matters
/// when writing statements.
struct StatementContext<'a> {
    /// Context used for every expression written as part of a statement.
    expression: ExpressionContext<'a>,
    /// Optional struct name associated with the function's result.
    // NOTE(review): presumably the name of the struct type used to return
    // values from the function being written — confirm against the code that
    // writes `Return` statements.
    result_struct: Option<&'a str>,
}
846
847impl<W: Write> Writer<W> {
848    /// Creates a new `Writer` instance.
849    pub fn new(out: W) -> Self {
850        Writer {
851            out,
852            names: FastHashMap::default(),
853            named_expressions: Default::default(),
854            need_bake_expressions: Default::default(),
855            namer: proc::Namer::default(),
856            wrapped_functions: FastHashSet::default(),
857            #[cfg(test)]
858            put_expression_stack_pointers: Default::default(),
859            #[cfg(test)]
860            put_block_stack_pointers: Default::default(),
861            struct_member_pads: FastHashSet::default(),
862        }
863    }
864
865    /// Finishes writing and returns the output.
866    // See https://github.com/rust-lang/rust-clippy/issues/4979.
867    #[allow(clippy::missing_const_for_fn)]
868    pub fn finish(self) -> W {
869        self.out
870    }
871
872    /// Generates statements to be inserted immediately before and at the very
873    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
874    /// The 0th item of the returned tuple should be inserted immediately prior
875    /// to the loop and the 1st item should be inserted at the very start of
876    /// the loop body.
877    ///
878    /// # What is this trying to solve?
879    ///
880    /// In Metal Shading Language, an infinite loop has undefined behavior.
881    /// (This rule is inherited from C++14.) This means that, if the MSL
882    /// compiler determines that a given loop will never exit, it may assume
883    /// that it is never reached. It may thus assume that any conditions
884    /// sufficient to cause the loop to be reached must be false. Like many
885    /// optimizing compilers, MSL uses this kind of analysis to establish limits
886    /// on the range of values variables involved in those conditions might
887    /// hold.
888    ///
889    /// For example, suppose the MSL compiler sees the code:
890    ///
891    /// ```ignore
892    /// if (i >= 10) {
893    ///     while (true) { }
894    /// }
895    /// ```
896    ///
897    /// It will recognize that the `while` loop will never terminate, conclude
898    /// that it must be unreachable, and thus infer that, if this code is
899    /// reached, then `i < 10` at that point.
900    ///
901    /// Now suppose that, at some point where `i` has the same value as above,
902    /// the compiler sees the code:
903    ///
904    /// ```ignore
905    /// if (i < 10) {
906    ///     a[i] = 1;
907    /// }
908    /// ```
909    ///
910    /// Because the compiler is confident that `i < 10`, it will make the
911    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
912    ///
913    /// ```ignore
914    /// a[i] = 1;
915    /// ```
916    ///
917    /// If that `if` condition was injected by Naga to implement a bounds check,
918    /// the MSL compiler's optimizations could allow out-of-bounds array
919    /// accesses to occur.
920    ///
921    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
922    /// that a loop is infinite, so an attacker could craft a Naga module
923    /// containing an infinite loop protected by conditions that cause the Metal
924    /// compiler to remove bounds checks that Naga injected elsewhere in the
925    /// function.
926    ///
927    /// This rewrite could occur even if the conditional assignment appears
928    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
929    /// reached. This would allow the attacker to save the results of
930    /// unauthorized reads somewhere accessible before entering the infinite
931    /// loop. But even worse, the MSL compiler has been observed to simply
932    /// delete the infinite loop entirely, so that even code dominated by the
933    /// loop becomes reachable. This would make the attack even more flexible,
934    /// since shaders that would appear to never terminate would actually exit
935    /// nicely, after having stolen data from elsewhere in the GPU address
936    /// space.
937    ///
938    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
939    /// generates is infinite. One approach would be to add inline assembly to
940    /// each loop that is annotated as potentially branching out of the loop,
941    /// but which in fact generates no instructions. Unfortunately, inline
942    /// assembly is not handled correctly by some Metal device drivers.
943    ///
944    /// A previously used approach was to add the following code to the bottom
945    /// of every loop:
946    ///
947    /// ```ignore
948    /// if (volatile bool unpredictable = false; unpredictable)
949    ///     break;
950    /// ```
951    ///
952    /// Although the `if` condition will always be false in any real execution,
953    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
954    /// it must assume that the `break` might be reached, and hence that the
955    /// loop is not unbounded. This prevents the range analysis impact described
956    /// above. Unfortunately this prevented the compiler from making important,
957    /// and safe, optimizations such as loop unrolling and was observed to
958    /// significantly hurt performance.
959    ///
    /// Our current approach declares a counter before every loop and
    /// decrements it every iteration, breaking after 2^64 iterations:
    ///
    /// ```ignore
    /// uint2 loop_bound = uint2(4294967295u);
    /// while (true) {
    ///   if (metal::all(loop_bound == uint2(0u))) { break; }
    ///   loop_bound -= uint2(loop_bound.y == 0u, 1u);
    /// }
    /// ```
970    ///
971    /// This convinces the compiler that the loop is finite and therefore may
972    /// execute, whilst at the same time allowing optimizations such as loop
973    /// unrolling. Furthermore the 64-bit counter is large enough it seems
974    /// implausible that it would affect the execution of any shader.
975    ///
976    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
977    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
978    fn gen_force_bounded_loop_statements(
979        &mut self,
980        level: back::Level,
981        context: &StatementContext,
982    ) -> Option<(String, String)> {
983        if !context.expression.force_loop_bounding {
984            return None;
985        }
986
987        let loop_bound_name = self.namer.call("loop_bound");
988        // Count down from u32::MAX rather than up from 0 to avoid hang on
989        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
990        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
991        let level = level.next();
992        let break_and_inc = format!(
993            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
994{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
995        );
996
997        Some((decl, break_and_inc))
998    }
999
1000    fn put_call_parameters(
1001        &mut self,
1002        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1003        context: &ExpressionContext,
1004    ) -> BackendResult {
1005        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1006            writer.put_expression(expr, context, true)
1007        })
1008    }
1009
1010    fn put_call_parameters_impl<C, E>(
1011        &mut self,
1012        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1013        ctx: &C,
1014        put_expression: E,
1015    ) -> BackendResult
1016    where
1017        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1018    {
1019        write!(self.out, "(")?;
1020        for (i, handle) in parameters.enumerate() {
1021            if i != 0 {
1022                write!(self.out, ", ")?;
1023            }
1024            put_expression(self, ctx, handle)?;
1025        }
1026        write!(self.out, ")")?;
1027        Ok(())
1028    }
1029
1030    /// Writes the local variables of the given function, as well as any extra
1031    /// out-of-bounds locals that are needed.
1032    ///
1033    /// The names of the OOB locals are also added to `self.names` at the same
1034    /// time.
1035    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1036        let oob_local_types = context.oob_local_types();
1037        for &ty in oob_local_types.iter() {
1038            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1039            self.names.insert(name_key, self.namer.call("oob"));
1040        }
1041
1042        for (name_key, ty, init) in context
1043            .function
1044            .local_variables
1045            .iter()
1046            .map(|(local_handle, local)| {
1047                let name_key = NameKey::local(context.origin, local_handle);
1048                (name_key, local.ty, local.init)
1049            })
1050            .chain(oob_local_types.iter().map(|&ty| {
1051                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1052                (name_key, ty, None)
1053            }))
1054        {
1055            let ty_name = TypeContext {
1056                handle: ty,
1057                gctx: context.module.to_ctx(),
1058                names: &self.names,
1059                access: crate::StorageAccess::empty(),
1060                first_time: false,
1061            };
1062            write!(
1063                self.out,
1064                "{}{} {}",
1065                back::INDENT,
1066                ty_name,
1067                self.names[&name_key]
1068            )?;
1069            match init {
1070                Some(value) => {
1071                    write!(self.out, " = ")?;
1072                    self.put_expression(value, context, true)?;
1073                }
1074                None => {
1075                    write!(self.out, " = {{}}")?;
1076                }
1077            };
1078            writeln!(self.out, ";")?;
1079        }
1080        Ok(())
1081    }
1082
1083    fn put_level_of_detail(
1084        &mut self,
1085        level: LevelOfDetail,
1086        context: &ExpressionContext,
1087    ) -> BackendResult {
1088        match level {
1089            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1090            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1091        }
1092        Ok(())
1093    }
1094
1095    fn put_image_query(
1096        &mut self,
1097        image: Handle<crate::Expression>,
1098        query: &str,
1099        level: Option<LevelOfDetail>,
1100        context: &ExpressionContext,
1101    ) -> BackendResult {
1102        self.put_expression(image, context, false)?;
1103        write!(self.out, ".get_{query}(")?;
1104        if let Some(level) = level {
1105            self.put_level_of_detail(level, context)?;
1106        }
1107        write!(self.out, ")")?;
1108        Ok(())
1109    }
1110
1111    fn put_image_size_query(
1112        &mut self,
1113        image: Handle<crate::Expression>,
1114        level: Option<LevelOfDetail>,
1115        kind: crate::ScalarKind,
1116        context: &ExpressionContext,
1117    ) -> BackendResult {
1118        if let crate::TypeInner::Image {
1119            class: crate::ImageClass::External,
1120            ..
1121        } = *context.resolve_type(image)
1122        {
1123            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1124            self.put_expression(image, context, true)?;
1125            write!(self.out, ")")?;
1126            return Ok(());
1127        }
1128
1129        //Note: MSL only has separate width/height/depth queries,
1130        // so compose the result of them.
1131        let dim = match *context.resolve_type(image) {
1132            crate::TypeInner::Image { dim, .. } => dim,
1133            ref other => unreachable!("Unexpected type {:?}", other),
1134        };
1135        let scalar = crate::Scalar { kind, width: 4 };
1136        let coordinate_type = scalar.to_msl_name();
1137        match dim {
1138            crate::ImageDimension::D1 => {
1139                // Since 1D textures never have mipmaps, MSL requires that the
1140                // `level` argument be a constexpr 0. It's simplest for us just
1141                // to pass `None` and omit the level entirely.
1142                if kind == crate::ScalarKind::Uint {
1143                    // No need to construct a vector. No cast needed.
1144                    self.put_image_query(image, "width", None, context)?;
1145                } else {
1146                    // There's no definition for `int` in the `metal` namespace.
1147                    write!(self.out, "int(")?;
1148                    self.put_image_query(image, "width", None, context)?;
1149                    write!(self.out, ")")?;
1150                }
1151            }
1152            crate::ImageDimension::D2 => {
1153                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1154                self.put_image_query(image, "width", level, context)?;
1155                write!(self.out, ", ")?;
1156                self.put_image_query(image, "height", level, context)?;
1157                write!(self.out, ")")?;
1158            }
1159            crate::ImageDimension::D3 => {
1160                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1161                self.put_image_query(image, "width", level, context)?;
1162                write!(self.out, ", ")?;
1163                self.put_image_query(image, "height", level, context)?;
1164                write!(self.out, ", ")?;
1165                self.put_image_query(image, "depth", level, context)?;
1166                write!(self.out, ")")?;
1167            }
1168            crate::ImageDimension::Cube => {
1169                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1170                self.put_image_query(image, "width", level, context)?;
1171                write!(self.out, ")")?;
1172            }
1173        }
1174        Ok(())
1175    }
1176
1177    fn put_cast_to_uint_scalar_or_vector(
1178        &mut self,
1179        expr: Handle<crate::Expression>,
1180        context: &ExpressionContext,
1181    ) -> BackendResult {
1182        // coordinates in IR are int, but Metal expects uint
1183        match *context.resolve_type(expr) {
1184            crate::TypeInner::Scalar(_) => {
1185                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1186            }
1187            crate::TypeInner::Vector { size, .. } => {
1188                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1189            }
1190            _ => {
1191                return Err(Error::GenericValidation(
1192                    "Invalid type for image coordinate".into(),
1193                ))
1194            }
1195        };
1196
1197        write!(self.out, "(")?;
1198        self.put_expression(expr, context, true)?;
1199        write!(self.out, ")")?;
1200        Ok(())
1201    }
1202
1203    fn put_image_sample_level(
1204        &mut self,
1205        image: Handle<crate::Expression>,
1206        level: crate::SampleLevel,
1207        context: &ExpressionContext,
1208    ) -> BackendResult {
1209        let has_levels = context.image_needs_lod(image);
1210        match level {
1211            crate::SampleLevel::Auto => {}
1212            crate::SampleLevel::Zero => {
1213                //TODO: do we support Zero on `Sampled` image classes?
1214            }
1215            _ if !has_levels => {
1216                log::warn!("1D image can't be sampled with level {level:?}");
1217            }
1218            crate::SampleLevel::Exact(h) => {
1219                write!(self.out, ", {NAMESPACE}::level(")?;
1220                self.put_expression(h, context, true)?;
1221                write!(self.out, ")")?;
1222            }
1223            crate::SampleLevel::Bias(h) => {
1224                write!(self.out, ", {NAMESPACE}::bias(")?;
1225                self.put_expression(h, context, true)?;
1226                write!(self.out, ")")?;
1227            }
1228            crate::SampleLevel::Gradient { x, y } => {
1229                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1230                self.put_expression(x, context, true)?;
1231                write!(self.out, ", ")?;
1232                self.put_expression(y, context, true)?;
1233                write!(self.out, ")")?;
1234            }
1235        }
1236        Ok(())
1237    }
1238
1239    fn put_image_coordinate_limits(
1240        &mut self,
1241        image: Handle<crate::Expression>,
1242        level: Option<LevelOfDetail>,
1243        context: &ExpressionContext,
1244    ) -> BackendResult {
1245        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1246        write!(self.out, " - 1")?;
1247        Ok(())
1248    }
1249
1250    /// General function for writing restricted image indexes.
1251    ///
1252    /// This is used to produce restricted mip levels, array indices, and sample
1253    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1254    /// [`Restrict`] bounds check policy.
1255    ///
1256    /// This function writes an expression of the form:
1257    ///
1258    /// ```ignore
1259    ///
1260    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1261    ///
1262    /// ```
1263    ///
1264    /// [`ImageLoad`]: crate::Expression::ImageLoad
1265    /// [`ImageStore`]: crate::Statement::ImageStore
1266    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1267    fn put_restricted_scalar_image_index(
1268        &mut self,
1269        image: Handle<crate::Expression>,
1270        index: Handle<crate::Expression>,
1271        limit_method: &str,
1272        context: &ExpressionContext,
1273    ) -> BackendResult {
1274        write!(self.out, "{NAMESPACE}::min(uint(")?;
1275        self.put_expression(index, context, true)?;
1276        write!(self.out, "), ")?;
1277        self.put_expression(image, context, false)?;
1278        write!(self.out, ".{limit_method}() - 1)")?;
1279        Ok(())
1280    }
1281
1282    fn put_restricted_texel_address(
1283        &mut self,
1284        image: Handle<crate::Expression>,
1285        address: &TexelAddress,
1286        context: &ExpressionContext,
1287    ) -> BackendResult {
1288        // Write the coordinate.
1289        write!(self.out, "{NAMESPACE}::min(")?;
1290        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1291        write!(self.out, ", ")?;
1292        self.put_image_coordinate_limits(image, address.level, context)?;
1293        write!(self.out, ")")?;
1294
1295        // Write the array index, if present.
1296        if let Some(array_index) = address.array_index {
1297            write!(self.out, ", ")?;
1298            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1299        }
1300
1301        // Write the sample index, if present.
1302        if let Some(sample) = address.sample {
1303            write!(self.out, ", ")?;
1304            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1305        }
1306
1307        // The level of detail should be clamped and cached by
1308        // `put_cache_restricted_level`, so we don't need to clamp it here.
1309        if let Some(level) = address.level {
1310            write!(self.out, ", ")?;
1311            self.put_level_of_detail(level, context)?;
1312        }
1313
1314        Ok(())
1315    }
1316
1317    /// Write an expression that is true if the given image access is in bounds.
1318    fn put_image_access_bounds_check(
1319        &mut self,
1320        image: Handle<crate::Expression>,
1321        address: &TexelAddress,
1322        context: &ExpressionContext,
1323    ) -> BackendResult {
1324        let mut conjunction = "";
1325
1326        // First, check the level of detail. Only if that is in bounds can we
1327        // use it to find the appropriate bounds for the coordinates.
1328        let level = if let Some(level) = address.level {
1329            write!(self.out, "uint(")?;
1330            self.put_level_of_detail(level, context)?;
1331            write!(self.out, ") < ")?;
1332            self.put_expression(image, context, true)?;
1333            write!(self.out, ".get_num_mip_levels()")?;
1334            conjunction = " && ";
1335            Some(level)
1336        } else {
1337            None
1338        };
1339
1340        // Check sample index, if present.
1341        if let Some(sample) = address.sample {
1342            write!(self.out, "uint(")?;
1343            self.put_expression(sample, context, true)?;
1344            write!(self.out, ") < ")?;
1345            self.put_expression(image, context, true)?;
1346            write!(self.out, ".get_num_samples()")?;
1347            conjunction = " && ";
1348        }
1349
1350        // Check array index, if present.
1351        if let Some(array_index) = address.array_index {
1352            write!(self.out, "{conjunction}uint(")?;
1353            self.put_expression(array_index, context, true)?;
1354            write!(self.out, ") < ")?;
1355            self.put_expression(image, context, true)?;
1356            write!(self.out, ".get_array_size()")?;
1357            conjunction = " && ";
1358        }
1359
1360        // Finally, check if the coordinates are within bounds.
1361        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1362            crate::TypeInner::Vector { .. } => true,
1363            _ => false,
1364        };
1365        write!(self.out, "{conjunction}")?;
1366        if coord_is_vector {
1367            write!(self.out, "{NAMESPACE}::all(")?;
1368        }
1369        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1370        write!(self.out, " < ")?;
1371        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1372        if coord_is_vector {
1373            write!(self.out, ")")?;
1374        }
1375
1376        Ok(())
1377    }
1378
1379    fn put_image_load(
1380        &mut self,
1381        load: Handle<crate::Expression>,
1382        image: Handle<crate::Expression>,
1383        mut address: TexelAddress,
1384        context: &ExpressionContext,
1385    ) -> BackendResult {
1386        if let crate::TypeInner::Image {
1387            class: crate::ImageClass::External,
1388            ..
1389        } = *context.resolve_type(image)
1390        {
1391            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1392            self.put_expression(image, context, true)?;
1393            write!(self.out, ", ")?;
1394            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1395            write!(self.out, ")")?;
1396            return Ok(());
1397        }
1398
1399        match context.policies.image_load {
1400            proc::BoundsCheckPolicy::Restrict => {
1401                // Use the cached restricted level of detail, if any. Omit the
1402                // level altogether for 1D textures.
1403                if address.level.is_some() {
1404                    address.level = if context.image_needs_lod(image) {
1405                        Some(LevelOfDetail::Restricted(load))
1406                    } else {
1407                        None
1408                    }
1409                }
1410
1411                self.put_expression(image, context, false)?;
1412                write!(self.out, ".read(")?;
1413                self.put_restricted_texel_address(image, &address, context)?;
1414                write!(self.out, ")")?;
1415            }
1416            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1417                write!(self.out, "(")?;
1418                self.put_image_access_bounds_check(image, &address, context)?;
1419                write!(self.out, " ? ")?;
1420                self.put_unchecked_image_load(image, &address, context)?;
1421                write!(self.out, ": DefaultConstructible())")?;
1422            }
1423            proc::BoundsCheckPolicy::Unchecked => {
1424                self.put_unchecked_image_load(image, &address, context)?;
1425            }
1426        }
1427
1428        Ok(())
1429    }
1430
1431    fn put_unchecked_image_load(
1432        &mut self,
1433        image: Handle<crate::Expression>,
1434        address: &TexelAddress,
1435        context: &ExpressionContext,
1436    ) -> BackendResult {
1437        self.put_expression(image, context, false)?;
1438        write!(self.out, ".read(")?;
1439        // coordinates in IR are int, but Metal expects uint
1440        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1441        if let Some(expr) = address.array_index {
1442            write!(self.out, ", ")?;
1443            self.put_expression(expr, context, true)?;
1444        }
1445        if let Some(sample) = address.sample {
1446            write!(self.out, ", ")?;
1447            self.put_expression(sample, context, true)?;
1448        }
1449        if let Some(level) = address.level {
1450            if context.image_needs_lod(image) {
1451                write!(self.out, ", ")?;
1452                self.put_level_of_detail(level, context)?;
1453            }
1454        }
1455        write!(self.out, ")")?;
1456
1457        Ok(())
1458    }
1459
1460    fn put_image_atomic(
1461        &mut self,
1462        level: back::Level,
1463        image: Handle<crate::Expression>,
1464        address: &TexelAddress,
1465        fun: crate::AtomicFunction,
1466        value: Handle<crate::Expression>,
1467        context: &StatementContext,
1468    ) -> BackendResult {
1469        write!(self.out, "{level}")?;
1470        self.put_expression(image, &context.expression, false)?;
1471        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1472            fun.to_msl_64_bit()?
1473        } else {
1474            fun.to_msl()
1475        };
1476        write!(self.out, ".atomic_{op}(")?;
1477        // coordinates in IR are int, but Metal expects uint
1478        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1479        write!(self.out, ", ")?;
1480        self.put_expression(value, &context.expression, true)?;
1481        writeln!(self.out, ");")?;
1482
1483        Ok(())
1484    }
1485
1486    fn put_image_store(
1487        &mut self,
1488        level: back::Level,
1489        image: Handle<crate::Expression>,
1490        address: &TexelAddress,
1491        value: Handle<crate::Expression>,
1492        context: &StatementContext,
1493    ) -> BackendResult {
1494        write!(self.out, "{level}")?;
1495        self.put_expression(image, &context.expression, false)?;
1496        write!(self.out, ".write(")?;
1497        self.put_expression(value, &context.expression, true)?;
1498        write!(self.out, ", ")?;
1499        // coordinates in IR are int, but Metal expects uint
1500        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1501        if let Some(expr) = address.array_index {
1502            write!(self.out, ", ")?;
1503            self.put_expression(expr, &context.expression, true)?;
1504        }
1505        writeln!(self.out, ");")?;
1506
1507        Ok(())
1508    }
1509
1510    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1511    ///
1512    /// The 'maximum valid index' is simply one less than the array's length.
1513    ///
1514    /// This emits an expression of the form `a / b`, so the caller must
1515    /// parenthesize its output if it will be applying operators of higher
1516    /// precedence.
1517    ///
1518    /// `handle` must be the handle of a global variable whose final member is a
1519    /// dynamically sized array.
1520    fn put_dynamic_array_max_index(
1521        &mut self,
1522        handle: Handle<crate::GlobalVariable>,
1523        context: &ExpressionContext,
1524    ) -> BackendResult {
1525        let global = &context.module.global_variables[handle];
1526        let (offset, array_ty) = match context.module.types[global.ty].inner {
1527            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1528                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1529                None => return Err(Error::GenericValidation("Struct has no members".into())),
1530            },
1531            crate::TypeInner::Array {
1532                size: crate::ArraySize::Dynamic,
1533                ..
1534            } => (0, global.ty),
1535            ref ty => {
1536                return Err(Error::GenericValidation(format!(
1537                    "Expected type with dynamic array, got {ty:?}"
1538                )))
1539            }
1540        };
1541
1542        let (size, stride) = match context.module.types[array_ty].inner {
1543            crate::TypeInner::Array { base, stride, .. } => (
1544                context.module.types[base]
1545                    .inner
1546                    .size(context.module.to_ctx()),
1547                stride,
1548            ),
1549            ref ty => {
1550                return Err(Error::GenericValidation(format!(
1551                    "Expected array type, got {ty:?}"
1552                )))
1553            }
1554        };
1555
1556        // When the stride length is larger than the size, the final element's stride of
1557        // bytes would have padding following the value. But the buffer size in
1558        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1559        // enough to hold the actual values' bytes.
1560        //
1561        // So subtract off the size to get a byte size that falls at the start or within
1562        // the final element. Then divide by the stride size, to get one less than the
1563        // length, and then add one. This works even if the buffer size does include the
1564        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1565        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1566        // rules, together with draw-time validation when `minBindingSize` is zero,
1567        // prevent that.
1568        write!(
1569            self.out,
1570            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1571            member = ArraySizeMember(handle),
1572            offset = offset,
1573            size = size,
1574            stride = stride,
1575        )?;
1576        Ok(())
1577    }
1578
1579    /// Emit code for the arithmetic expression of the dot product.
1580    ///
1581    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1582    /// an index. It writes out the expression for the vector component at that index.
1583    fn put_dot_product<T: Copy>(
1584        &mut self,
1585        arg: T,
1586        arg1: T,
1587        size: usize,
1588        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1589    ) -> BackendResult {
1590        // Write parentheses around the dot product expression to prevent operators
1591        // with different precedences from applying earlier.
1592        write!(self.out, "(")?;
1593
1594        // Cycle through all the components of the vector
1595        for index in 0..size {
1596            // Write the addition to the previous product
1597            // This will print an extra '+' at the beginning but that is fine in msl
1598            write!(self.out, " + ")?;
1599            extractor(self, arg, index)?;
1600            write!(self.out, " * ")?;
1601            extractor(self, arg1, index)?;
1602        }
1603
1604        write!(self.out, ")")?;
1605        Ok(())
1606    }
1607
1608    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1609    fn put_pack4x8(
1610        &mut self,
1611        arg: Handle<crate::Expression>,
1612        context: &ExpressionContext<'_>,
1613        was_signed: bool,
1614        clamp_bounds: Option<(&str, &str)>,
1615    ) -> Result<(), Error> {
1616        let write_arg = |this: &mut Self| -> BackendResult {
1617            if let Some((min, max)) = clamp_bounds {
1618                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1619                write!(this.out, "{NAMESPACE}::clamp(")?;
1620                this.put_expression(arg, context, true)?;
1621                write!(this.out, ", {min}, {max})")?;
1622            } else {
1623                this.put_expression(arg, context, true)?;
1624            }
1625            Ok(())
1626        };
1627
1628        if context.lang_version >= (2, 1) {
1629            let packed_type = if was_signed {
1630                "packed_char4"
1631            } else {
1632                "packed_uchar4"
1633            };
1634            // Metal uses little endian byte order, which matches what WGSL expects here.
1635            write!(self.out, "as_type<uint>({packed_type}(")?;
1636            write_arg(self)?;
1637            write!(self.out, "))")?;
1638        } else {
1639            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1640            if was_signed {
1641                write!(self.out, "uint(")?;
1642            }
1643            write!(self.out, "(")?;
1644            write_arg(self)?;
1645            write!(self.out, "[0] & 0xFF) | ((")?;
1646            write_arg(self)?;
1647            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1648            write_arg(self)?;
1649            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1650            write_arg(self)?;
1651            write!(self.out, "[3] & 0xFF) << 24)")?;
1652            if was_signed {
1653                write!(self.out, ")")?;
1654            }
1655        }
1656
1657        Ok(())
1658    }
1659
1660    /// Emit code for the isign expression.
1661    ///
1662    fn put_isign(
1663        &mut self,
1664        arg: Handle<crate::Expression>,
1665        context: &ExpressionContext,
1666    ) -> BackendResult {
1667        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1668        let scalar = context
1669            .resolve_type(arg)
1670            .scalar()
1671            .expect("put_isign should only be called for args which have an integer scalar type")
1672            .to_msl_name();
1673        match context.resolve_type(arg) {
1674            &crate::TypeInner::Vector { size, .. } => {
1675                let size = common::vector_size_str(size);
1676                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1677            }
1678            _ => {
1679                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1680            }
1681        }
1682        write!(self.out, ", (")?;
1683        self.put_expression(arg, context, true)?;
1684        write!(self.out, " > 0)), {scalar}(0), (")?;
1685        self.put_expression(arg, context, true)?;
1686        write!(self.out, " == 0))")?;
1687        Ok(())
1688    }
1689
1690    fn put_const_expression(
1691        &mut self,
1692        expr_handle: Handle<crate::Expression>,
1693        module: &crate::Module,
1694        mod_info: &valid::ModuleInfo,
1695        arena: &crate::Arena<crate::Expression>,
1696    ) -> BackendResult {
1697        self.put_possibly_const_expression(
1698            expr_handle,
1699            arena,
1700            module,
1701            mod_info,
1702            &(module, mod_info),
1703            |&(_, mod_info), expr| &mod_info[expr],
1704            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1705        )
1706    }
1707
1708    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1709        match literal {
1710            crate::Literal::F64(_) => {
1711                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1712            }
1713            crate::Literal::F16(value) => {
1714                if value.is_infinite() {
1715                    let sign = if value.is_sign_negative() { "-" } else { "" };
1716                    write!(self.out, "{sign}INFINITY")?;
1717                } else if value.is_nan() {
1718                    write!(self.out, "NAN")?;
1719                } else {
1720                    let suffix = if value.fract() == f16::from_f32(0.0) {
1721                        ".0h"
1722                    } else {
1723                        "h"
1724                    };
1725                    write!(self.out, "{value}{suffix}")?;
1726                }
1727            }
1728            crate::Literal::F32(value) => {
1729                if value.is_infinite() {
1730                    let sign = if value.is_sign_negative() { "-" } else { "" };
1731                    write!(self.out, "{sign}INFINITY")?;
1732                } else if value.is_nan() {
1733                    write!(self.out, "NAN")?;
1734                } else {
1735                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1736                    write!(self.out, "{value}{suffix}")?;
1737                }
1738            }
1739            crate::Literal::U32(value) => {
1740                write!(self.out, "{value}u")?;
1741            }
1742            crate::Literal::I32(value) => {
1743                // `-2147483648` is parsed as unary negation of positive 2147483648.
1744                // 2147483648 is too large for int32_t meaning the expression gets
1745                // promoted to a int64_t which is not our intention. Avoid this by instead
1746                // using `-2147483647 - 1`.
1747                if value == i32::MIN {
1748                    write!(self.out, "({} - 1)", value + 1)?;
1749                } else {
1750                    write!(self.out, "{value}")?;
1751                }
1752            }
1753            crate::Literal::U64(value) => {
1754                write!(self.out, "{value}uL")?;
1755            }
1756            crate::Literal::I64(value) => {
1757                // `-9223372036854775808` is parsed as unary negation of positive
1758                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1759                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1760                // value to `-9223372036854775808`. Which would then be negated, possibly
1761                // causing undefined behaviour. Avoid this by instead using
1762                // `-9223372036854775808L - 1L`.
1763                if value == i64::MIN {
1764                    write!(self.out, "({}L - 1L)", value + 1)?;
1765                } else {
1766                    write!(self.out, "{value}L")?;
1767                }
1768            }
1769            crate::Literal::Bool(value) => {
1770                write!(self.out, "{value}")?;
1771            }
1772            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1773                return Err(Error::GenericValidation(
1774                    "Unsupported abstract literal".into(),
1775                ));
1776            }
1777        }
1778        Ok(())
1779    }
1780
1781    #[allow(clippy::too_many_arguments)]
1782    fn put_possibly_const_expression<C, I, E>(
1783        &mut self,
1784        expr_handle: Handle<crate::Expression>,
1785        expressions: &crate::Arena<crate::Expression>,
1786        module: &crate::Module,
1787        mod_info: &valid::ModuleInfo,
1788        ctx: &C,
1789        get_expr_ty: I,
1790        put_expression: E,
1791    ) -> BackendResult
1792    where
1793        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1794        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1795    {
1796        match expressions[expr_handle] {
1797            crate::Expression::Literal(literal) => {
1798                self.put_literal(literal)?;
1799            }
1800            crate::Expression::Constant(handle) => {
1801                let constant = &module.constants[handle];
1802                if constant.name.is_some() {
1803                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1804                } else {
1805                    self.put_const_expression(
1806                        constant.init,
1807                        module,
1808                        mod_info,
1809                        &module.global_expressions,
1810                    )?;
1811                }
1812            }
1813            crate::Expression::ZeroValue(ty) => {
1814                let ty_name = TypeContext {
1815                    handle: ty,
1816                    gctx: module.to_ctx(),
1817                    names: &self.names,
1818                    access: crate::StorageAccess::empty(),
1819                    first_time: false,
1820                };
1821                write!(self.out, "{ty_name} {{}}")?;
1822            }
1823            crate::Expression::Compose { ty, ref components } => {
1824                let ty_name = TypeContext {
1825                    handle: ty,
1826                    gctx: module.to_ctx(),
1827                    names: &self.names,
1828                    access: crate::StorageAccess::empty(),
1829                    first_time: false,
1830                };
1831                write!(self.out, "{ty_name}")?;
1832                match module.types[ty].inner {
1833                    crate::TypeInner::Scalar(_)
1834                    | crate::TypeInner::Vector { .. }
1835                    | crate::TypeInner::Matrix { .. } => {
1836                        self.put_call_parameters_impl(
1837                            components.iter().copied(),
1838                            ctx,
1839                            put_expression,
1840                        )?;
1841                    }
1842                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
1843                        write!(self.out, " {{")?;
1844                        for (index, &component) in components.iter().enumerate() {
1845                            if index != 0 {
1846                                write!(self.out, ", ")?;
1847                            }
1848                            // insert padding initialization, if needed
1849                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1850                                write!(self.out, "{{}}, ")?;
1851                            }
1852                            put_expression(self, ctx, component)?;
1853                        }
1854                        write!(self.out, "}}")?;
1855                    }
1856                    _ => return Err(Error::UnsupportedCompose(ty)),
1857                }
1858            }
1859            crate::Expression::Splat { size, value } => {
1860                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1861                    crate::TypeInner::Scalar(scalar) => scalar,
1862                    ref ty => {
1863                        return Err(Error::GenericValidation(format!(
1864                            "Expected splat value type must be a scalar, got {ty:?}",
1865                        )))
1866                    }
1867                };
1868                put_numeric_type(&mut self.out, scalar, &[size])?;
1869                write!(self.out, "(")?;
1870                put_expression(self, ctx, value)?;
1871                write!(self.out, ")")?;
1872            }
1873            _ => {
1874                return Err(Error::Override);
1875            }
1876        }
1877
1878        Ok(())
1879    }
1880
1881    /// Emit code for the expression `expr_handle`.
1882    ///
1883    /// The `is_scoped` argument is true if the surrounding operators have the
1884    /// precedence of the comma operator, or lower. So, for example:
1885    ///
1886    /// - Pass `true` for `is_scoped` when writing function arguments, an
1887    ///   expression statement, an initializer expression, or anything already
1888    ///   wrapped in parenthesis.
1889    ///
1890    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1891    ///   almost anything else.
1892    fn put_expression(
1893        &mut self,
1894        expr_handle: Handle<crate::Expression>,
1895        context: &ExpressionContext,
1896        is_scoped: bool,
1897    ) -> BackendResult {
1898        // Add to the set in order to track the stack size.
1899        #[cfg(test)]
1900        self.put_expression_stack_pointers
1901            .insert(ptr::from_ref(&expr_handle).cast());
1902
1903        if let Some(name) = self.named_expressions.get(&expr_handle) {
1904            write!(self.out, "{name}")?;
1905            return Ok(());
1906        }
1907
1908        let expression = &context.function.expressions[expr_handle];
1909        match *expression {
1910            crate::Expression::Literal(_)
1911            | crate::Expression::Constant(_)
1912            | crate::Expression::ZeroValue(_)
1913            | crate::Expression::Compose { .. }
1914            | crate::Expression::Splat { .. } => {
1915                self.put_possibly_const_expression(
1916                    expr_handle,
1917                    &context.function.expressions,
1918                    context.module,
1919                    context.mod_info,
1920                    context,
1921                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1922                    |writer, context, expr| writer.put_expression(expr, context, true),
1923                )?;
1924            }
1925            crate::Expression::Override(_) => return Err(Error::Override),
1926            crate::Expression::Access { base, .. }
1927            | crate::Expression::AccessIndex { base, .. } => {
1928                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
1929                // Since `put_bounds_checks` and `put_access_chain` handle an entire
1930                // access chain at a time, recursing back through `put_expression` only
1931                // for index expressions and the base object, we will never see intermediate
1932                // `Access` or `AccessIndex` expressions here.
1933                let policy = context.choose_bounds_check_policy(base);
1934                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1935                    && self.put_bounds_checks(
1936                        expr_handle,
1937                        context,
1938                        back::Level(0),
1939                        if is_scoped { "" } else { "(" },
1940                    )?
1941                {
1942                    write!(self.out, " ? ")?;
1943                    self.put_access_chain(expr_handle, policy, context)?;
1944                    write!(self.out, " : ")?;
1945
1946                    if context.resolve_type(base).pointer_space().is_some() {
1947                        // We can't just use `DefaultConstructible` if this is a pointer.
1948                        // Instead, we create a dummy local variable to serve as pointer
1949                        // target if the access is out of bounds.
1950                        let result_ty = context.info[expr_handle]
1951                            .ty
1952                            .inner_with(&context.module.types)
1953                            .pointer_base_type();
1954                        let result_ty_handle = match result_ty {
1955                            Some(TypeResolution::Handle(handle)) => handle,
1956                            Some(TypeResolution::Value(_)) => {
1957                                // As long as the result of a pointer access expression is
1958                                // passed to a function or stored in a let binding, the
1959                                // type will be in the arena. If additional uses of
1960                                // pointers become valid, this assumption might no longer
1961                                // hold. Note that the LHS of a load or store doesn't
1962                                // take this path -- there is dedicated code in `put_load`
1963                                // and `put_store`.
1964                                unreachable!(
1965                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1966                                );
1967                            }
1968                            None => {
1969                                unreachable!(
1970                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1971                                )
1972                            }
1973                        };
1974                        let name_key =
1975                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
1976                        self.out.write_str(&self.names[&name_key])?;
1977                    } else {
1978                        write!(self.out, "DefaultConstructible()")?;
1979                    }
1980
1981                    if !is_scoped {
1982                        write!(self.out, ")")?;
1983                    }
1984                } else {
1985                    self.put_access_chain(expr_handle, policy, context)?;
1986                }
1987            }
1988            crate::Expression::Swizzle {
1989                size,
1990                vector,
1991                pattern,
1992            } => {
1993                self.put_wrapped_expression_for_packed_vec3_access(
1994                    vector,
1995                    context,
1996                    false,
1997                    &Self::put_expression,
1998                )?;
1999                write!(self.out, ".")?;
2000                for &sc in pattern[..size as usize].iter() {
2001                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2002                }
2003            }
2004            crate::Expression::FunctionArgument(index) => {
2005                let name_key = match context.origin {
2006                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2007                    FunctionOrigin::EntryPoint(ep_index) => {
2008                        NameKey::EntryPointArgument(ep_index, index)
2009                    }
2010                };
2011                let name = &self.names[&name_key];
2012                write!(self.out, "{name}")?;
2013            }
2014            crate::Expression::GlobalVariable(handle) => {
2015                let name = &self.names[&NameKey::GlobalVariable(handle)];
2016                write!(self.out, "{name}")?;
2017            }
2018            crate::Expression::LocalVariable(handle) => {
2019                let name_key = NameKey::local(context.origin, handle);
2020                let name = &self.names[&name_key];
2021                write!(self.out, "{name}")?;
2022            }
2023            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2024            crate::Expression::ImageSample {
2025                coordinate,
2026                image,
2027                sampler,
2028                clamp_to_edge: true,
2029                gather: None,
2030                array_index: None,
2031                offset: None,
2032                level: crate::SampleLevel::Zero,
2033                depth_ref: None,
2034            } => {
2035                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2036                self.put_expression(image, context, true)?;
2037                write!(self.out, ", ")?;
2038                self.put_expression(sampler, context, true)?;
2039                write!(self.out, ", ")?;
2040                self.put_expression(coordinate, context, true)?;
2041                write!(self.out, ")")?;
2042            }
2043            crate::Expression::ImageSample {
2044                image,
2045                sampler,
2046                gather,
2047                coordinate,
2048                array_index,
2049                offset,
2050                level,
2051                depth_ref,
2052                clamp_to_edge,
2053            } => {
2054                if clamp_to_edge {
2055                    return Err(Error::GenericValidation(
2056                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2057                    ));
2058                }
2059
2060                let main_op = match gather {
2061                    Some(_) => "gather",
2062                    None => "sample",
2063                };
2064                let comparison_op = match depth_ref {
2065                    Some(_) => "_compare",
2066                    None => "",
2067                };
2068                self.put_expression(image, context, false)?;
2069                write!(self.out, ".{main_op}{comparison_op}(")?;
2070                self.put_expression(sampler, context, true)?;
2071                write!(self.out, ", ")?;
2072                self.put_expression(coordinate, context, true)?;
2073                if let Some(expr) = array_index {
2074                    write!(self.out, ", ")?;
2075                    self.put_expression(expr, context, true)?;
2076                }
2077                if let Some(dref) = depth_ref {
2078                    write!(self.out, ", ")?;
2079                    self.put_expression(dref, context, true)?;
2080                }
2081
2082                self.put_image_sample_level(image, level, context)?;
2083
2084                if let Some(offset) = offset {
2085                    write!(self.out, ", ")?;
2086                    self.put_expression(offset, context, true)?;
2087                }
2088
2089                match gather {
2090                    None | Some(crate::SwizzleComponent::X) => {}
2091                    Some(component) => {
2092                        let is_cube_map = match *context.resolve_type(image) {
2093                            crate::TypeInner::Image {
2094                                dim: crate::ImageDimension::Cube,
2095                                ..
2096                            } => true,
2097                            _ => false,
2098                        };
2099                        // Offset always comes before the gather, except
2100                        // in cube maps where it's not applicable
2101                        if offset.is_none() && !is_cube_map {
2102                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2103                        }
2104                        let letter = back::COMPONENTS[component as usize];
2105                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2106                    }
2107                }
2108                write!(self.out, ")")?;
2109            }
2110            crate::Expression::ImageLoad {
2111                image,
2112                coordinate,
2113                array_index,
2114                sample,
2115                level,
2116            } => {
2117                let address = TexelAddress {
2118                    coordinate,
2119                    array_index,
2120                    sample,
2121                    level: level.map(LevelOfDetail::Direct),
2122                };
2123                self.put_image_load(expr_handle, image, address, context)?;
2124            }
2125            //Note: for all the queries, the signed integers are expected,
2126            // so a conversion is needed.
2127            crate::Expression::ImageQuery { image, query } => match query {
2128                crate::ImageQuery::Size { level } => {
2129                    self.put_image_size_query(
2130                        image,
2131                        level.map(LevelOfDetail::Direct),
2132                        crate::ScalarKind::Uint,
2133                        context,
2134                    )?;
2135                }
2136                crate::ImageQuery::NumLevels => {
2137                    self.put_expression(image, context, false)?;
2138                    write!(self.out, ".get_num_mip_levels()")?;
2139                }
2140                crate::ImageQuery::NumLayers => {
2141                    self.put_expression(image, context, false)?;
2142                    write!(self.out, ".get_array_size()")?;
2143                }
2144                crate::ImageQuery::NumSamples => {
2145                    self.put_expression(image, context, false)?;
2146                    write!(self.out, ".get_num_samples()")?;
2147                }
2148            },
2149            crate::Expression::Unary { op, expr } => {
2150                let op_str = match op {
2151                    crate::UnaryOperator::Negate => {
2152                        match context.resolve_type(expr).scalar_kind() {
2153                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2154                            _ => "-",
2155                        }
2156                    }
2157                    crate::UnaryOperator::LogicalNot => "!",
2158                    crate::UnaryOperator::BitwiseNot => "~",
2159                };
2160                write!(self.out, "{op_str}(")?;
2161                self.put_expression(expr, context, false)?;
2162                write!(self.out, ")")?;
2163            }
2164            crate::Expression::Binary { op, left, right } => {
2165                let kind = context
2166                    .resolve_type(left)
2167                    .scalar_kind()
2168                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2169
2170                if op == crate::BinaryOperator::Divide
2171                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2172                {
2173                    write!(self.out, "{DIV_FUNCTION}(")?;
2174                    self.put_expression(left, context, true)?;
2175                    write!(self.out, ", ")?;
2176                    self.put_expression(right, context, true)?;
2177                    write!(self.out, ")")?;
2178                } else if op == crate::BinaryOperator::Modulo
2179                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2180                {
2181                    write!(self.out, "{MOD_FUNCTION}(")?;
2182                    self.put_expression(left, context, true)?;
2183                    write!(self.out, ", ")?;
2184                    self.put_expression(right, context, true)?;
2185                    write!(self.out, ")")?;
2186                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2187                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2188                    //
2189                    // float:
2190                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2191                    write!(self.out, "{NAMESPACE}::fmod(")?;
2192                    self.put_expression(left, context, true)?;
2193                    write!(self.out, ", ")?;
2194                    self.put_expression(right, context, true)?;
2195                    write!(self.out, ")")?;
2196                } else if (op == crate::BinaryOperator::Add
2197                    || op == crate::BinaryOperator::Subtract
2198                    || op == crate::BinaryOperator::Multiply)
2199                    && kind == crate::ScalarKind::Sint
2200                {
2201                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2202                        crate::TypeInner::Scalar(scalar) => {
2203                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2204                                kind: crate::ScalarKind::Uint,
2205                                ..scalar
2206                            }))
2207                        }
2208                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2209                            size,
2210                            scalar: crate::Scalar {
2211                                kind: crate::ScalarKind::Uint,
2212                                ..scalar
2213                            },
2214                        }),
2215                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2216                    };
2217
2218                    // Avoid undefined behaviour due to overflowing signed
2219                    // integer arithmetic. Cast the operands to unsigned prior
2220                    // to performing the operation, then cast the result back
2221                    // to signed.
2222                    self.put_bitcasted_expression(
2223                        context.resolve_type(expr_handle),
2224                        context,
2225                        &|writer, context, is_scoped| {
2226                            writer.put_binop(
2227                                op,
2228                                left,
2229                                right,
2230                                context,
2231                                is_scoped,
2232                                &|writer, expr, context, _is_scoped| {
2233                                    writer.put_bitcasted_expression(
2234                                        &to_unsigned(context.resolve_type(expr))?,
2235                                        context,
2236                                        &|writer, context, is_scoped| {
2237                                            writer.put_expression(expr, context, is_scoped)
2238                                        },
2239                                    )
2240                                },
2241                            )
2242                        },
2243                    )?;
2244                } else {
2245                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2246                }
2247            }
2248            crate::Expression::Select {
2249                condition,
2250                accept,
2251                reject,
2252            } => match *context.resolve_type(condition) {
2253                crate::TypeInner::Scalar(crate::Scalar {
2254                    kind: crate::ScalarKind::Bool,
2255                    ..
2256                }) => {
2257                    if !is_scoped {
2258                        write!(self.out, "(")?;
2259                    }
2260                    self.put_expression(condition, context, false)?;
2261                    write!(self.out, " ? ")?;
2262                    self.put_expression(accept, context, false)?;
2263                    write!(self.out, " : ")?;
2264                    self.put_expression(reject, context, false)?;
2265                    if !is_scoped {
2266                        write!(self.out, ")")?;
2267                    }
2268                }
2269                crate::TypeInner::Vector {
2270                    scalar:
2271                        crate::Scalar {
2272                            kind: crate::ScalarKind::Bool,
2273                            ..
2274                        },
2275                    ..
2276                } => {
2277                    write!(self.out, "{NAMESPACE}::select(")?;
2278                    self.put_expression(reject, context, true)?;
2279                    write!(self.out, ", ")?;
2280                    self.put_expression(accept, context, true)?;
2281                    write!(self.out, ", ")?;
2282                    self.put_expression(condition, context, true)?;
2283                    write!(self.out, ")")?;
2284                }
2285                ref ty => {
2286                    return Err(Error::GenericValidation(format!(
2287                        "Expected select condition to be a non-bool type, got {ty:?}",
2288                    )))
2289                }
2290            },
2291            crate::Expression::Derivative { axis, expr, .. } => {
2292                use crate::DerivativeAxis as Axis;
2293                let op = match axis {
2294                    Axis::X => "dfdx",
2295                    Axis::Y => "dfdy",
2296                    Axis::Width => "fwidth",
2297                };
2298                write!(self.out, "{NAMESPACE}::{op}")?;
2299                self.put_call_parameters(iter::once(expr), context)?;
2300            }
2301            crate::Expression::Relational { fun, argument } => {
2302                let op = match fun {
2303                    crate::RelationalFunction::Any => "any",
2304                    crate::RelationalFunction::All => "all",
2305                    crate::RelationalFunction::IsNan => "isnan",
2306                    crate::RelationalFunction::IsInf => "isinf",
2307                };
2308                write!(self.out, "{NAMESPACE}::{op}")?;
2309                self.put_call_parameters(iter::once(argument), context)?;
2310            }
2311            crate::Expression::Math {
2312                fun,
2313                arg,
2314                arg1,
2315                arg2,
2316                arg3,
2317            } => {
2318                use crate::MathFunction as Mf;
2319
2320                let arg_type = context.resolve_type(arg);
2321                let scalar_argument = match arg_type {
2322                    &crate::TypeInner::Scalar(_) => true,
2323                    _ => false,
2324                };
2325
2326                let fun_name = match fun {
2327                    // comparison
2328                    Mf::Abs => "abs",
2329                    Mf::Min => "min",
2330                    Mf::Max => "max",
2331                    Mf::Clamp => "clamp",
2332                    Mf::Saturate => "saturate",
2333                    // trigonometry
2334                    Mf::Cos => "cos",
2335                    Mf::Cosh => "cosh",
2336                    Mf::Sin => "sin",
2337                    Mf::Sinh => "sinh",
2338                    Mf::Tan => "tan",
2339                    Mf::Tanh => "tanh",
2340                    Mf::Acos => "acos",
2341                    Mf::Asin => "asin",
2342                    Mf::Atan => "atan",
2343                    Mf::Atan2 => "atan2",
2344                    Mf::Asinh => "asinh",
2345                    Mf::Acosh => "acosh",
2346                    Mf::Atanh => "atanh",
2347                    Mf::Radians => "",
2348                    Mf::Degrees => "",
2349                    // decomposition
2350                    Mf::Ceil => "ceil",
2351                    Mf::Floor => "floor",
2352                    Mf::Round => "rint",
2353                    Mf::Fract => "fract",
2354                    Mf::Trunc => "trunc",
2355                    Mf::Modf => MODF_FUNCTION,
2356                    Mf::Frexp => FREXP_FUNCTION,
2357                    Mf::Ldexp => "ldexp",
2358                    // exponent
2359                    Mf::Exp => "exp",
2360                    Mf::Exp2 => "exp2",
2361                    Mf::Log => "log",
2362                    Mf::Log2 => "log2",
2363                    Mf::Pow => "pow",
2364                    // geometry
2365                    Mf::Dot => match *context.resolve_type(arg) {
2366                        crate::TypeInner::Vector {
2367                            scalar:
2368                                crate::Scalar {
2369                                    // Resolve float values to MSL's builtin dot function.
2370                                    kind: crate::ScalarKind::Float,
2371                                    ..
2372                                },
2373                            ..
2374                        } => "dot",
2375                        crate::TypeInner::Vector {
2376                            size,
2377                            scalar:
2378                                scalar @ crate::Scalar {
2379                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2380                                    ..
2381                                },
2382                        } => {
2383                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2384                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2385                            write!(self.out, "{fun_name}(")?;
2386                            self.put_expression(arg, context, true)?;
2387                            write!(self.out, ", ")?;
2388                            self.put_expression(arg1.unwrap(), context, true)?;
2389                            write!(self.out, ")")?;
2390                            return Ok(());
2391                        }
2392                        _ => unreachable!(
2393                            "Correct TypeInner for dot product should be already validated"
2394                        ),
2395                    },
2396                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2397                        if context.lang_version >= (2, 1) {
2398                            // Write potentially optimizable code using `packed_(u?)char4`.
2399                            // The two function arguments were already reinterpreted as packed (signed
2400                            // or unsigned) chars in `Self::put_block`.
2401                            let packed_type = match fun {
2402                                Mf::Dot4I8Packed => "packed_char4",
2403                                Mf::Dot4U8Packed => "packed_uchar4",
2404                                _ => unreachable!(),
2405                            };
2406
2407                            return self.put_dot_product(
2408                                Reinterpreted::new(packed_type, arg),
2409                                Reinterpreted::new(packed_type, arg1.unwrap()),
2410                                4,
2411                                |writer, arg, index| {
2412                                    // MSL implicitly promotes these (signed or unsigned) chars to
2413                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2414                                    write!(writer.out, "{arg}[{index}]")?;
2415                                    Ok(())
2416                                },
2417                            );
2418                        } else {
2419                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2420                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2421                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2422                            let conversion = match fun {
2423                                Mf::Dot4I8Packed => "int",
2424                                Mf::Dot4U8Packed => "",
2425                                _ => unreachable!(),
2426                            };
2427
2428                            return self.put_dot_product(
2429                                arg,
2430                                arg1.unwrap(),
2431                                4,
2432                                |writer, arg, index| {
2433                                    write!(writer.out, "({conversion}(")?;
2434                                    writer.put_expression(arg, context, true)?;
2435                                    if index == 3 {
2436                                        write!(writer.out, ") >> 24)")?;
2437                                    } else {
2438                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2439                                    }
2440                                    Ok(())
2441                                },
2442                            );
2443                        }
2444                    }
2445                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2446                    Mf::Cross => "cross",
2447                    Mf::Distance => "distance",
2448                    Mf::Length if scalar_argument => "abs",
2449                    Mf::Length => "length",
2450                    Mf::Normalize => "normalize",
2451                    Mf::FaceForward => "faceforward",
2452                    Mf::Reflect => "reflect",
2453                    Mf::Refract => "refract",
2454                    // computational
2455                    Mf::Sign => match arg_type.scalar_kind() {
2456                        Some(crate::ScalarKind::Sint) => {
2457                            return self.put_isign(arg, context);
2458                        }
2459                        _ => "sign",
2460                    },
2461                    Mf::Fma => "fma",
2462                    Mf::Mix => "mix",
2463                    Mf::Step => "step",
2464                    Mf::SmoothStep => "smoothstep",
2465                    Mf::Sqrt => "sqrt",
2466                    Mf::InverseSqrt => "rsqrt",
2467                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2468                    Mf::Transpose => "transpose",
2469                    Mf::Determinant => "determinant",
2470                    Mf::QuantizeToF16 => "",
2471                    // bits
2472                    Mf::CountTrailingZeros => "ctz",
2473                    Mf::CountLeadingZeros => "clz",
2474                    Mf::CountOneBits => "popcount",
2475                    Mf::ReverseBits => "reverse_bits",
2476                    Mf::ExtractBits => "",
2477                    Mf::InsertBits => "",
2478                    Mf::FirstTrailingBit => "",
2479                    Mf::FirstLeadingBit => "",
2480                    // data packing
2481                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2482                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2483                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2484                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2485                    Mf::Pack2x16float => "",
2486                    Mf::Pack4xI8 => "",
2487                    Mf::Pack4xU8 => "",
2488                    Mf::Pack4xI8Clamp => "",
2489                    Mf::Pack4xU8Clamp => "",
2490                    // data unpacking
2491                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2492                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2493                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2494                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2495                    Mf::Unpack2x16float => "",
2496                    Mf::Unpack4xI8 => "",
2497                    Mf::Unpack4xU8 => "",
2498                };
2499
2500                match fun {
2501                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2502                        // reverse_bits is listed as requiring MSL 2.1 but that
2503                        // is a copy/paste error. Looking at previous snapshots
2504                        // on web.archive.org it's present in MSL 1.2.
2505                        //
2506                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2507                        // also talks about MSL 1.2 adding "New integer
2508                        // functions to extract, insert, and reverse bits, as
2509                        // described in Integer Functions."
2510                        if context.lang_version < (1, 2) {
2511                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2512                        }
2513                    }
2514                    _ => {}
2515                }
2516
2517                match fun {
2518                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2519                        write!(self.out, "{ABS_FUNCTION}(")?;
2520                        self.put_expression(arg, context, true)?;
2521                        write!(self.out, ")")?;
2522                    }
2523                    Mf::Distance if scalar_argument => {
2524                        write!(self.out, "{NAMESPACE}::abs(")?;
2525                        self.put_expression(arg, context, false)?;
2526                        write!(self.out, " - ")?;
2527                        self.put_expression(arg1.unwrap(), context, false)?;
2528                        write!(self.out, ")")?;
2529                    }
2530                    Mf::FirstTrailingBit => {
2531                        let scalar = context.resolve_type(arg).scalar().unwrap();
2532                        let constant = scalar.width * 8 + 1;
2533
2534                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2535                        self.put_expression(arg, context, true)?;
2536                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2537                    }
2538                    Mf::FirstLeadingBit => {
2539                        let inner = context.resolve_type(arg);
2540                        let scalar = inner.scalar().unwrap();
2541                        let constant = scalar.width * 8 - 1;
2542
2543                        write!(
2544                            self.out,
2545                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2546                        )?;
2547
2548                        if scalar.kind == crate::ScalarKind::Sint {
2549                            write!(self.out, "{NAMESPACE}::select(")?;
2550                            self.put_expression(arg, context, true)?;
2551                            write!(self.out, ", ~")?;
2552                            self.put_expression(arg, context, true)?;
2553                            write!(self.out, ", ")?;
2554                            self.put_expression(arg, context, true)?;
2555                            write!(self.out, " < 0)")?;
2556                        } else {
2557                            self.put_expression(arg, context, true)?;
2558                        }
2559
2560                        write!(self.out, "), ")?;
2561
2562                        // or metal will complain that select is ambiguous
2563                        match *inner {
2564                            crate::TypeInner::Vector { size, scalar } => {
2565                                let size = common::vector_size_str(size);
2566                                let name = scalar.to_msl_name();
2567                                write!(self.out, "{name}{size}")?;
2568                            }
2569                            crate::TypeInner::Scalar(scalar) => {
2570                                let name = scalar.to_msl_name();
2571                                write!(self.out, "{name}")?;
2572                            }
2573                            _ => (),
2574                        }
2575
2576                        write!(self.out, "(-1), ")?;
2577                        self.put_expression(arg, context, true)?;
2578                        write!(self.out, " == 0 || ")?;
2579                        self.put_expression(arg, context, true)?;
2580                        write!(self.out, " == -1)")?;
2581                    }
2582                    Mf::Unpack2x16float => {
2583                        write!(self.out, "float2(as_type<half2>(")?;
2584                        self.put_expression(arg, context, false)?;
2585                        write!(self.out, "))")?;
2586                    }
2587                    Mf::Pack2x16float => {
2588                        write!(self.out, "as_type<uint>(half2(")?;
2589                        self.put_expression(arg, context, false)?;
2590                        write!(self.out, "))")?;
2591                    }
2592                    Mf::ExtractBits => {
2593                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2594                        // to sanitize the offset and count first. If we don't do this, Apple chips
2595                        // will return out-of-spec values if the extracted range is not within the bit width.
2596                        //
2597                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2598                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2599                        //
2600                        // w = sizeof(x) * 8
2601                        // o = min(offset, w)
2602                        // tmp = w - o
2603                        // c = min(count, tmp)
2604                        //
2605                        // bitfieldExtract(x, o, c)
2606                        //
2607                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))
2608
2609                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2610
2611                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2612                        self.put_expression(arg, context, true)?;
2613                        write!(self.out, ", {NAMESPACE}::min(")?;
2614                        self.put_expression(arg1.unwrap(), context, true)?;
2615                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2616                        self.put_expression(arg2.unwrap(), context, true)?;
2617                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2618                        self.put_expression(arg1.unwrap(), context, true)?;
2619                        write!(self.out, ", {scalar_bits}u)))")?;
2620                    }
2621                    Mf::InsertBits => {
2622                        // The behavior of InsertBits has the same issue as ExtractBits.
2623                        //
2624                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))
2625
2626                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2627
2628                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2629                        self.put_expression(arg, context, true)?;
2630                        write!(self.out, ", ")?;
2631                        self.put_expression(arg1.unwrap(), context, true)?;
2632                        write!(self.out, ", {NAMESPACE}::min(")?;
2633                        self.put_expression(arg2.unwrap(), context, true)?;
2634                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2635                        self.put_expression(arg3.unwrap(), context, true)?;
2636                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2637                        self.put_expression(arg2.unwrap(), context, true)?;
2638                        write!(self.out, ", {scalar_bits}u)))")?;
2639                    }
2640                    Mf::Radians => {
2641                        write!(self.out, "((")?;
2642                        self.put_expression(arg, context, false)?;
2643                        write!(self.out, ") * 0.017453292519943295474)")?;
2644                    }
2645                    Mf::Degrees => {
2646                        write!(self.out, "((")?;
2647                        self.put_expression(arg, context, false)?;
2648                        write!(self.out, ") * 57.295779513082322865)")?;
2649                    }
2650                    Mf::Modf | Mf::Frexp => {
2651                        write!(self.out, "{fun_name}")?;
2652                        self.put_call_parameters(iter::once(arg), context)?;
2653                    }
2654                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2655                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2656                    Mf::Pack4xI8Clamp => {
2657                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2658                    }
2659                    Mf::Pack4xU8Clamp => {
2660                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2661                    }
2662                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2663                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2664                            "u"
2665                        } else {
2666                            ""
2667                        };
2668
2669                        if context.lang_version >= (2, 1) {
2670                            // Metal uses little endian byte order, which matches what WGSL expects here.
2671                            write!(
2672                                self.out,
2673                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2674                            )?;
2675                            self.put_expression(arg, context, true)?;
2676                            write!(self.out, "))")?;
2677                        } else {
2678                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2679                            write!(self.out, "({sign_prefix}int4(")?;
2680                            self.put_expression(arg, context, true)?;
2681                            write!(self.out, ", ")?;
2682                            self.put_expression(arg, context, true)?;
2683                            write!(self.out, " >> 8, ")?;
2684                            self.put_expression(arg, context, true)?;
2685                            write!(self.out, " >> 16, ")?;
2686                            self.put_expression(arg, context, true)?;
2687                            write!(self.out, " >> 24) << 24 >> 24)")?;
2688                        }
2689                    }
2690                    Mf::QuantizeToF16 => {
2691                        match *context.resolve_type(arg) {
2692                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2693                            crate::TypeInner::Vector { size, .. } => write!(
2694                                self.out,
2695                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2696                                size = common::vector_size_str(size),
2697                            )?,
2698                            _ => unreachable!(
2699                                "Correct TypeInner for QuantizeToF16 should be already validated"
2700                            ),
2701                        };
2702
2703                        self.put_expression(arg, context, true)?;
2704                        write!(self.out, "))")?;
2705                    }
2706                    _ => {
2707                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2708                        self.put_call_parameters(
2709                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2710                            context,
2711                        )?;
2712                    }
2713                }
2714            }
2715            crate::Expression::As {
2716                expr,
2717                kind,
2718                convert,
2719            } => match *context.resolve_type(expr) {
2720                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2721                    if src.kind == crate::ScalarKind::Float
2722                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2723                        && convert.is_some()
2724                    {
2725                        // Use helper functions for float to int casts in order to avoid
2726                        // undefined behaviour when value is out of range for the target
2727                        // type.
2728                        let fun_name = match (kind, convert) {
2729                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2730                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2731                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2732                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2733                            _ => unreachable!(),
2734                        };
2735                        write!(self.out, "{fun_name}(")?;
2736                        self.put_expression(expr, context, true)?;
2737                        write!(self.out, ")")?;
2738                    } else {
2739                        let target_scalar = crate::Scalar {
2740                            kind,
2741                            width: convert.unwrap_or(src.width),
2742                        };
2743                        let op = match convert {
2744                            Some(_) => "static_cast",
2745                            None => "as_type",
2746                        };
2747                        write!(self.out, "{op}<")?;
2748                        match *context.resolve_type(expr) {
2749                            crate::TypeInner::Vector { size, .. } => {
2750                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2751                            }
2752                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2753                        };
2754                        write!(self.out, ">(")?;
2755                        self.put_expression(expr, context, true)?;
2756                        write!(self.out, ")")?;
2757                    }
2758                }
2759                crate::TypeInner::Matrix {
2760                    columns,
2761                    rows,
2762                    scalar,
2763                } => {
2764                    let target_scalar = crate::Scalar {
2765                        kind,
2766                        width: convert.unwrap_or(scalar.width),
2767                    };
2768                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2769                    write!(self.out, "(")?;
2770                    self.put_expression(expr, context, true)?;
2771                    write!(self.out, ")")?;
2772                }
2773                ref ty => {
2774                    return Err(Error::GenericValidation(format!(
2775                        "Unsupported type for As: {ty:?}"
2776                    )))
2777                }
2778            },
2779            // has to be a named expression
2780            crate::Expression::CallResult(_)
2781            | crate::Expression::AtomicResult { .. }
2782            | crate::Expression::WorkGroupUniformLoadResult { .. }
2783            | crate::Expression::SubgroupBallotResult
2784            | crate::Expression::SubgroupOperationResult { .. }
2785            | crate::Expression::RayQueryProceedResult => {
2786                unreachable!()
2787            }
2788            crate::Expression::ArrayLength(expr) => {
2789                // Find the global to which the array belongs.
2790                let global = match context.function.expressions[expr] {
2791                    crate::Expression::AccessIndex { base, .. } => {
2792                        match context.function.expressions[base] {
2793                            crate::Expression::GlobalVariable(handle) => handle,
2794                            ref ex => {
2795                                return Err(Error::GenericValidation(format!(
2796                                    "Expected global variable in AccessIndex, got {ex:?}"
2797                                )))
2798                            }
2799                        }
2800                    }
2801                    crate::Expression::GlobalVariable(handle) => handle,
2802                    ref ex => {
2803                        return Err(Error::GenericValidation(format!(
2804                            "Unexpected expression in ArrayLength, got {ex:?}"
2805                        )))
2806                    }
2807                };
2808
2809                if !is_scoped {
2810                    write!(self.out, "(")?;
2811                }
2812                write!(self.out, "1 + ")?;
2813                self.put_dynamic_array_max_index(global, context)?;
2814                if !is_scoped {
2815                    write!(self.out, ")")?;
2816                }
2817            }
2818            crate::Expression::RayQueryVertexPositions { .. } => {
2819                unimplemented!()
2820            }
2821            crate::Expression::RayQueryGetIntersection {
2822                query,
2823                committed: _,
2824            } => {
2825                if context.lang_version < (2, 4) {
2826                    return Err(Error::UnsupportedRayTracing);
2827                }
2828
2829                let ty = context.module.special_types.ray_intersection.unwrap();
2830                let type_name = &self.names[&NameKey::Type(ty)];
2831                write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2832                self.put_expression(query, context, true)?;
2833                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2834                let fields = [
2835                    "distance",
2836                    "user_instance_id", // req Metal 2.4
2837                    "instance_id",
2838                    "", // SBT offset
2839                    "geometry_id",
2840                    "primitive_id",
2841                    "triangle_barycentric_coord",
2842                    "triangle_front_facing",
2843                    "",                          // padding
2844                    "object_to_world_transform", // req Metal 2.4
2845                    "world_to_object_transform", // req Metal 2.4
2846                ];
2847                for field in fields {
2848                    write!(self.out, ", ")?;
2849                    if field.is_empty() {
2850                        write!(self.out, "{{}}")?;
2851                    } else {
2852                        self.put_expression(query, context, true)?;
2853                        write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2854                    }
2855                }
2856                write!(self.out, "}}")?;
2857            }
2858            crate::Expression::CooperativeLoad { ref data, .. } => {
2859                if context.lang_version < (2, 3) {
2860                    return Err(Error::UnsupportedCooperativeMatrix);
2861                }
2862                write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2863                write!(self.out, "&")?;
2864                self.put_access_chain(data.pointer, context.policies.index, context)?;
2865                write!(self.out, ", ")?;
2866                self.put_expression(data.stride, context, true)?;
2867                write!(self.out, ", {})", data.row_major)?;
2868            }
2869            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2870                if context.lang_version < (2, 3) {
2871                    return Err(Error::UnsupportedCooperativeMatrix);
2872                }
2873                write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2874                self.put_expression(a, context, true)?;
2875                write!(self.out, ", ")?;
2876                self.put_expression(b, context, true)?;
2877                write!(self.out, ", ")?;
2878                self.put_expression(c, context, true)?;
2879                write!(self.out, ")")?;
2880            }
2881        }
2882        Ok(())
2883    }
2884
2885    /// Emits code for a binary operation, using the provided callback to emit
2886    /// the left and right operands.
2887    fn put_binop<F>(
2888        &mut self,
2889        op: crate::BinaryOperator,
2890        left: Handle<crate::Expression>,
2891        right: Handle<crate::Expression>,
2892        context: &ExpressionContext,
2893        is_scoped: bool,
2894        put_expression: &F,
2895    ) -> BackendResult
2896    where
2897        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2898    {
2899        let op_str = back::binary_operation_str(op);
2900
2901        if !is_scoped {
2902            write!(self.out, "(")?;
2903        }
2904
2905        // Cast packed vector if necessary
2906        // Packed vector - matrix multiplications are not supported in MSL
2907        if op == crate::BinaryOperator::Multiply
2908            && matches!(
2909                context.resolve_type(right),
2910                &crate::TypeInner::Matrix { .. }
2911            )
2912        {
2913            self.put_wrapped_expression_for_packed_vec3_access(
2914                left,
2915                context,
2916                false,
2917                put_expression,
2918            )?;
2919        } else {
2920            put_expression(self, left, context, false)?;
2921        }
2922
2923        write!(self.out, " {op_str} ")?;
2924
2925        // See comment above
2926        if op == crate::BinaryOperator::Multiply
2927            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2928        {
2929            self.put_wrapped_expression_for_packed_vec3_access(
2930                right,
2931                context,
2932                false,
2933                put_expression,
2934            )?;
2935        } else {
2936            put_expression(self, right, context, false)?;
2937        }
2938
2939        if !is_scoped {
2940            write!(self.out, ")")?;
2941        }
2942
2943        Ok(())
2944    }
2945
2946    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
2947    fn put_wrapped_expression_for_packed_vec3_access<F>(
2948        &mut self,
2949        expr_handle: Handle<crate::Expression>,
2950        context: &ExpressionContext,
2951        is_scoped: bool,
2952        put_expression: &F,
2953    ) -> BackendResult
2954    where
2955        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2956    {
2957        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2958            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2959            put_expression(self, expr_handle, context, is_scoped)?;
2960            write!(self.out, ")")?;
2961        } else {
2962            put_expression(self, expr_handle, context, is_scoped)?;
2963        }
2964        Ok(())
2965    }
2966
2967    /// Emits code for an expression using the provided callback, wrapping the
2968    /// result in a bitcast to the type `cast_to`.
2969    fn put_bitcasted_expression<F>(
2970        &mut self,
2971        cast_to: &crate::TypeInner,
2972        context: &ExpressionContext,
2973        put_expression: &F,
2974    ) -> BackendResult
2975    where
2976        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2977    {
2978        write!(self.out, "as_type<")?;
2979        match *cast_to {
2980            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2981            crate::TypeInner::Vector { size, scalar } => {
2982                put_numeric_type(&mut self.out, scalar, &[size])?
2983            }
2984            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2985        };
2986        write!(self.out, ">(")?;
2987        put_expression(self, context, true)?;
2988        write!(self.out, ")")?;
2989
2990        Ok(())
2991    }
2992
2993    /// Write a `GuardedIndex` as a Metal expression.
2994    fn put_index(
2995        &mut self,
2996        index: index::GuardedIndex,
2997        context: &ExpressionContext,
2998        is_scoped: bool,
2999    ) -> BackendResult {
3000        match index {
3001            index::GuardedIndex::Expression(expr) => {
3002                self.put_expression(expr, context, is_scoped)?
3003            }
3004            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3005        }
3006        Ok(())
3007    }
3008
3009    /// Emit an index bounds check condition for `chain`, if required.
3010    ///
3011    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
3012    /// operating either on a pointer to a value, or on a value directly. If we cannot
3013    /// statically determine that all indexing operations in `chain` are within
3014    /// bounds, then write a conditional expression to check them dynamically,
3015    /// and return true. All accesses in the chain are checked by the generated
3016    /// expression.
3017    ///
3018    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
3019    ///
3020    /// The text written is of the form:
3021    ///
3022    /// ```ignore
3023    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
3024    /// ```
3025    ///
3026    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
3027    /// statements, presumably these arguments start an indented `if` statement; for
3028    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
3029    /// expression. In either case, what is written is not a complete syntactic structure
3030    /// in its own right, and the caller will have to finish it off if we return `true`.
3031    ///
3032    /// If no expression is written, return false.
3033    ///
3034    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
3035    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3036    /// [`Store`]: crate::Statement::Store
3037    /// [`Load`]: crate::Expression::Load
3038    #[allow(unused_variables)]
3039    fn put_bounds_checks(
3040        &mut self,
3041        chain: Handle<crate::Expression>,
3042        context: &ExpressionContext,
3043        level: back::Level,
3044        prefix: &'static str,
3045    ) -> Result<bool, Error> {
3046        let mut check_written = false;
3047
3048        // Iterate over the access chain, handling each required bounds check.
3049        for item in context.bounds_check_iter(chain) {
3050            let BoundsCheck {
3051                base,
3052                index,
3053                length,
3054            } = item;
3055
3056            if check_written {
3057                write!(self.out, " && ")?;
3058            } else {
3059                write!(self.out, "{level}{prefix}")?;
3060                check_written = true;
3061            }
3062
3063            // Check that the index falls within bounds. Do this with a single
3064            // comparison, by casting the index to `uint` first, so that negative
3065            // indices become large positive values.
3066            write!(self.out, "uint(")?;
3067            self.put_index(index, context, true)?;
3068            self.out.write_str(") < ")?;
3069            match length {
3070                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3071                index::IndexableLength::Dynamic => {
3072                    let global = context.function.originating_global(base).ok_or_else(|| {
3073                        Error::GenericValidation("Could not find originating global".into())
3074                    })?;
3075                    write!(self.out, "1 + ")?;
3076                    self.put_dynamic_array_max_index(global, context)?
3077                }
3078            }
3079        }
3080
3081        Ok(check_written)
3082    }
3083
3084    /// Write the access chain `chain`.
3085    ///
3086    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3087    /// operating either on a pointer to a value, or on a value directly.
3088    ///
3089    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3090    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3091    /// that must be handled in the caller.
3092    ///
3093    /// Handle the entire chain, recursing back into `put_expression` only for index
3094    /// expressions and the base expression that originates the pointer or composite value
3095    /// being accessed. This allows `put_expression` to assume that any `Access` or
3096    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3097    /// `ReadZeroSkipWrite` checks.
3098    ///
3099    /// [`Access`]: crate::Expression::Access
3100    /// [`AccessIndex`]: crate::Expression::AccessIndex
3101    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3102    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3103    fn put_access_chain(
3104        &mut self,
3105        chain: Handle<crate::Expression>,
3106        policy: index::BoundsCheckPolicy,
3107        context: &ExpressionContext,
3108    ) -> BackendResult {
3109        match context.function.expressions[chain] {
3110            crate::Expression::Access { base, index } => {
3111                let mut base_ty = context.resolve_type(base);
3112
3113                // Look through any pointers to see what we're really indexing.
3114                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3115                    base_ty = &context.module.types[base].inner;
3116                }
3117
3118                self.put_subscripted_access_chain(
3119                    base,
3120                    base_ty,
3121                    index::GuardedIndex::Expression(index),
3122                    policy,
3123                    context,
3124                )?;
3125            }
3126            crate::Expression::AccessIndex { base, index } => {
3127                let base_resolution = &context.info[base].ty;
3128                let mut base_ty = base_resolution.inner_with(&context.module.types);
3129                let mut base_ty_handle = base_resolution.handle();
3130
3131                // Look through any pointers to see what we're really indexing.
3132                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3133                    base_ty = &context.module.types[base].inner;
3134                    base_ty_handle = Some(base);
3135                }
3136
3137                // Handle structs and anything else that can use `.x` syntax here, so
3138                // `put_subscripted_access_chain` won't have to handle the absurd case of
3139                // indexing a struct with an expression.
3140                match *base_ty {
3141                    crate::TypeInner::Struct { .. } => {
3142                        let base_ty = base_ty_handle.unwrap();
3143                        self.put_access_chain(base, policy, context)?;
3144                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3145                        write!(self.out, ".{name}")?;
3146                    }
3147                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3148                        self.put_access_chain(base, policy, context)?;
3149                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3150                        // however array indexing is
3151                        if context.get_packed_vec_kind(base).is_some() {
3152                            write!(self.out, "[{index}]")?;
3153                        } else {
3154                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3155                        }
3156                    }
3157                    _ => {
3158                        self.put_subscripted_access_chain(
3159                            base,
3160                            base_ty,
3161                            index::GuardedIndex::Known(index),
3162                            policy,
3163                            context,
3164                        )?;
3165                    }
3166                }
3167            }
3168            _ => self.put_expression(chain, context, false)?,
3169        }
3170
3171        Ok(())
3172    }
3173
3174    /// Write a `[]`-style access of `base` by `index`.
3175    ///
3176    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3177    /// values within bounds.
3178    ///
3179    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3180    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3181    /// removed. Our callers often already have this handy.
3182    ///
3183    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3184    /// referencing vector components by name.
3185    ///
3186    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3187    /// [`Array`]: crate::TypeInner::Array
3188    /// [`Vector`]: crate::TypeInner::Vector
3189    /// [`Pointer`]: crate::TypeInner::Pointer
3190    fn put_subscripted_access_chain(
3191        &mut self,
3192        base: Handle<crate::Expression>,
3193        base_ty: &crate::TypeInner,
3194        index: index::GuardedIndex,
3195        policy: index::BoundsCheckPolicy,
3196        context: &ExpressionContext,
3197    ) -> BackendResult {
3198        let accessing_wrapped_array = match *base_ty {
3199            crate::TypeInner::Array {
3200                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3201                ..
3202            } => true,
3203            _ => false,
3204        };
3205        let accessing_wrapped_binding_array =
3206            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3207
3208        self.put_access_chain(base, policy, context)?;
3209        if accessing_wrapped_array {
3210            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3211        }
3212        write!(self.out, "[")?;
3213
3214        // Decide whether this index needs to be clamped to fall within range.
3215        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3216            context.access_needs_check(base, index)
3217        } else {
3218            None
3219        };
3220        if let Some(limit) = restriction_needed {
3221            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3222            self.put_index(index, context, true)?;
3223            write!(self.out, "), ")?;
3224            match limit {
3225                index::IndexableLength::Known(limit) => {
3226                    write!(self.out, "{}u", limit - 1)?;
3227                }
3228                index::IndexableLength::Dynamic => {
3229                    let global = context.function.originating_global(base).ok_or_else(|| {
3230                        Error::GenericValidation("Could not find originating global".into())
3231                    })?;
3232                    self.put_dynamic_array_max_index(global, context)?;
3233                }
3234            }
3235            write!(self.out, ")")?;
3236        } else {
3237            self.put_index(index, context, true)?;
3238        }
3239
3240        write!(self.out, "]")?;
3241
3242        if accessing_wrapped_binding_array {
3243            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3244        }
3245
3246        Ok(())
3247    }
3248
3249    fn put_load(
3250        &mut self,
3251        pointer: Handle<crate::Expression>,
3252        context: &ExpressionContext,
3253        is_scoped: bool,
3254    ) -> BackendResult {
3255        // Since access chains never cross between address spaces, we can just
3256        // check the index bounds check policy once at the top.
3257        let policy = context.choose_bounds_check_policy(pointer);
3258        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3259            && self.put_bounds_checks(
3260                pointer,
3261                context,
3262                back::Level(0),
3263                if is_scoped { "" } else { "(" },
3264            )?
3265        {
3266            write!(self.out, " ? ")?;
3267            self.put_unchecked_load(pointer, policy, context)?;
3268            write!(self.out, " : DefaultConstructible()")?;
3269
3270            if !is_scoped {
3271                write!(self.out, ")")?;
3272            }
3273        } else {
3274            self.put_unchecked_load(pointer, policy, context)?;
3275        }
3276
3277        Ok(())
3278    }
3279
3280    fn put_unchecked_load(
3281        &mut self,
3282        pointer: Handle<crate::Expression>,
3283        policy: index::BoundsCheckPolicy,
3284        context: &ExpressionContext,
3285    ) -> BackendResult {
3286        let is_atomic_pointer = context
3287            .resolve_type(pointer)
3288            .is_atomic_pointer(&context.module.types);
3289
3290        if is_atomic_pointer {
3291            write!(
3292                self.out,
3293                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3294            )?;
3295            self.put_access_chain(pointer, policy, context)?;
3296            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3297        } else {
3298            // We don't do any dereferencing with `*` here as pointer arguments to functions
3299            // are done by `&` references and not `*` pointers. These do not need to be
3300            // dereferenced.
3301            self.put_access_chain(pointer, policy, context)?;
3302        }
3303
3304        Ok(())
3305    }
3306
3307    fn put_return_value(
3308        &mut self,
3309        level: back::Level,
3310        expr_handle: Handle<crate::Expression>,
3311        result_struct: Option<&str>,
3312        context: &ExpressionContext,
3313    ) -> BackendResult {
3314        match result_struct {
3315            Some(struct_name) => {
3316                let mut has_point_size = false;
3317                let result_ty = context.function.result.as_ref().unwrap().ty;
3318                match context.module.types[result_ty].inner {
3319                    crate::TypeInner::Struct { ref members, .. } => {
3320                        let tmp = "_tmp";
3321                        write!(self.out, "{level}const auto {tmp} = ")?;
3322                        self.put_expression(expr_handle, context, true)?;
3323                        writeln!(self.out, ";")?;
3324                        write!(self.out, "{level}return {struct_name} {{")?;
3325
3326                        let mut is_first = true;
3327
3328                        for (index, member) in members.iter().enumerate() {
3329                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3330                                member.binding
3331                            {
3332                                has_point_size = true;
3333                                if !context.pipeline_options.allow_and_force_point_size {
3334                                    continue;
3335                                }
3336                            }
3337
3338                            let comma = if is_first { "" } else { "," };
3339                            is_first = false;
3340                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3341                            // HACK: we are forcefully deduplicating the expression here
3342                            // to convert from a wrapped struct to a raw array, e.g.
3343                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3344                            if let crate::TypeInner::Array {
3345                                size: crate::ArraySize::Constant(size),
3346                                ..
3347                            } = context.module.types[member.ty].inner
3348                            {
3349                                write!(self.out, "{comma} {{")?;
3350                                for j in 0..size.get() {
3351                                    if j != 0 {
3352                                        write!(self.out, ",")?;
3353                                    }
3354                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3355                                }
3356                                write!(self.out, "}}")?;
3357                            } else {
3358                                write!(self.out, "{comma} {tmp}.{name}")?;
3359                            }
3360                        }
3361                    }
3362                    _ => {
3363                        write!(self.out, "{level}return {struct_name} {{ ")?;
3364                        self.put_expression(expr_handle, context, true)?;
3365                    }
3366                }
3367
3368                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3369                    let stage = context.module.entry_points[ep_index as usize].stage;
3370                    if context.pipeline_options.allow_and_force_point_size
3371                        && stage == crate::ShaderStage::Vertex
3372                        && !has_point_size
3373                    {
3374                        // point size was injected and comes last
3375                        write!(self.out, ", 1.0")?;
3376                    }
3377                }
3378                write!(self.out, " }}")?;
3379            }
3380            None => {
3381                write!(self.out, "{level}return ")?;
3382                self.put_expression(expr_handle, context, true)?;
3383            }
3384        }
3385        writeln!(self.out, ";")?;
3386        Ok(())
3387    }
3388
3389    /// Helper method used to find which expressions of a given function require baking
3390    ///
3391    /// # Notes
3392    /// This function overwrites the contents of `self.need_bake_expressions`
3393    fn update_expressions_to_bake(
3394        &mut self,
3395        func: &crate::Function,
3396        info: &valid::FunctionInfo,
3397        context: &ExpressionContext,
3398    ) {
3399        use crate::Expression;
3400        self.need_bake_expressions.clear();
3401
3402        for (expr_handle, expr) in func.expressions.iter() {
3403            // Expressions whose reference count is above the
3404            // threshold should always be stored in temporaries.
3405            let expr_info = &info[expr_handle];
3406            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3407            if min_ref_count <= expr_info.ref_count {
3408                self.need_bake_expressions.insert(expr_handle);
3409            } else {
3410                match expr_info.ty {
3411                    // force ray desc to be baked: it's used multiple times internally
3412                    TypeResolution::Handle(h)
3413                        if Some(h) == context.module.special_types.ray_desc =>
3414                    {
3415                        self.need_bake_expressions.insert(expr_handle);
3416                    }
3417                    _ => {}
3418                }
3419            }
3420
3421            if let Expression::Math {
3422                fun,
3423                arg,
3424                arg1,
3425                arg2,
3426                ..
3427            } = *expr
3428            {
3429                match fun {
3430                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3431                    // works on floating-point vectors, so we emit inline code for
3432                    // integer vector `dot` calls. But that code uses each argument `N`
3433                    // times, once for each component (see `put_dot_product`), so to
3434                    // avoid duplicated evaluation, we must bake integer operands.
3435                    // This applies both when using the polyfill (because of the duplicate
3436                    // evaluation issue) and when we don't use the polyfill (because we
3437                    // need them to be emitted before casting to packed chars -- see the
3438                    // comment at the call to `put_casting_to_packed_chars`).
3439                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3440                        self.need_bake_expressions.insert(arg);
3441                        self.need_bake_expressions.insert(arg1.unwrap());
3442                    }
3443                    crate::MathFunction::FirstLeadingBit => {
3444                        self.need_bake_expressions.insert(arg);
3445                    }
3446                    crate::MathFunction::Pack4xI8
3447                    | crate::MathFunction::Pack4xU8
3448                    | crate::MathFunction::Pack4xI8Clamp
3449                    | crate::MathFunction::Pack4xU8Clamp
3450                    | crate::MathFunction::Unpack4xI8
3451                    | crate::MathFunction::Unpack4xU8 => {
3452                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3453                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3454                        if context.lang_version < (2, 1) {
3455                            self.need_bake_expressions.insert(arg);
3456                        }
3457                    }
3458                    crate::MathFunction::ExtractBits => {
3459                        // Only argument 1 is re-used.
3460                        self.need_bake_expressions.insert(arg1.unwrap());
3461                    }
3462                    crate::MathFunction::InsertBits => {
3463                        // Only argument 2 is re-used.
3464                        self.need_bake_expressions.insert(arg2.unwrap());
3465                    }
3466                    crate::MathFunction::Sign => {
3467                        // WGSL's `sign` function works also on signed ints, but Metal's only
3468                        // works on floating points, so we emit inline code for integer `sign`
3469                        // calls. But that code uses each argument 2 times (see `put_isign`),
3470                        // so to avoid duplicated evaluation, we must bake the argument.
3471                        let inner = context.resolve_type(expr_handle);
3472                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3473                            self.need_bake_expressions.insert(arg);
3474                        }
3475                    }
3476                    _ => {}
3477                }
3478            }
3479        }
3480    }
3481
3482    fn start_baking_expression(
3483        &mut self,
3484        handle: Handle<crate::Expression>,
3485        context: &ExpressionContext,
3486        name: &str,
3487    ) -> BackendResult {
3488        match context.info[handle].ty {
3489            TypeResolution::Handle(ty_handle) => {
3490                let ty_name = TypeContext {
3491                    handle: ty_handle,
3492                    gctx: context.module.to_ctx(),
3493                    names: &self.names,
3494                    access: crate::StorageAccess::empty(),
3495                    first_time: false,
3496                };
3497                write!(self.out, "{ty_name}")?;
3498            }
3499            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3500                put_numeric_type(&mut self.out, scalar, &[])?;
3501            }
3502            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3503                put_numeric_type(&mut self.out, scalar, &[size])?;
3504            }
3505            TypeResolution::Value(crate::TypeInner::Matrix {
3506                columns,
3507                rows,
3508                scalar,
3509            }) => {
3510                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3511            }
3512            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3513                columns,
3514                rows,
3515                scalar,
3516                role: _,
3517            }) => {
3518                write!(
3519                    self.out,
3520                    "{}::simdgroup_{}{}x{}",
3521                    NAMESPACE,
3522                    scalar.to_msl_name(),
3523                    columns as u32,
3524                    rows as u32,
3525                )?;
3526            }
3527            TypeResolution::Value(ref other) => {
3528                log::warn!("Type {other:?} isn't a known local");
3529                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3530            }
3531        }
3532
3533        //TODO: figure out the naming scheme that wouldn't collide with user names.
3534        write!(self.out, " {name} = ")?;
3535
3536        Ok(())
3537    }
3538
3539    /// Cache a clamped level of detail value, if necessary.
3540    ///
3541    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3542    /// properly clamped level of detail value both in the access itself, and
3543    /// for fetching the size of the requested MIP level, needed to clamp the
3544    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3545    /// it in a temporary variable, as part of the [`Emit`] statement covering
3546    /// the [`ImageLoad`] expression.
3547    ///
3548    /// [`ImageLoad`]: crate::Expression::ImageLoad
3549    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3550    /// [`Emit`]: crate::Statement::Emit
3551    fn put_cache_restricted_level(
3552        &mut self,
3553        load: Handle<crate::Expression>,
3554        image: Handle<crate::Expression>,
3555        mip_level: Option<Handle<crate::Expression>>,
3556        indent: back::Level,
3557        context: &StatementContext,
3558    ) -> BackendResult {
3559        // Does this image access actually require (or even permit) a
3560        // level-of-detail, and does the policy require us to restrict it?
3561        let level_of_detail = match mip_level {
3562            Some(level) => level,
3563            None => return Ok(()),
3564        };
3565
3566        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3567            || !context.expression.image_needs_lod(image)
3568        {
3569            return Ok(());
3570        }
3571
3572        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3573        self.put_restricted_scalar_image_index(
3574            image,
3575            level_of_detail,
3576            "get_num_mip_levels",
3577            &context.expression,
3578        )?;
3579        writeln!(self.out, ";")?;
3580
3581        Ok(())
3582    }
3583
3584    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3585    ///
3586    /// Caches the results in temporary variables (whose names are derived from
3587    /// the original variable names). This caching avoids the need to redo the
3588    /// casting for each vector component when emitting the dot product.
3589    fn put_casting_to_packed_chars(
3590        &mut self,
3591        fun: crate::MathFunction,
3592        arg0: Handle<crate::Expression>,
3593        arg1: Handle<crate::Expression>,
3594        indent: back::Level,
3595        context: &StatementContext<'_>,
3596    ) -> Result<(), Error> {
3597        let packed_type = match fun {
3598            crate::MathFunction::Dot4I8Packed => "packed_char4",
3599            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3600            _ => unreachable!(),
3601        };
3602
3603        for arg in [arg0, arg1] {
3604            write!(
3605                self.out,
3606                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3607                Reinterpreted::new(packed_type, arg)
3608            )?;
3609            self.put_expression(arg, &context.expression, true)?;
3610            writeln!(self.out, ");")?;
3611        }
3612
3613        Ok(())
3614    }
3615
3616    fn put_block(
3617        &mut self,
3618        level: back::Level,
3619        statements: &[crate::Statement],
3620        context: &StatementContext,
3621    ) -> BackendResult {
3622        // Add to the set in order to track the stack size.
3623        #[cfg(test)]
3624        self.put_block_stack_pointers
3625            .insert(ptr::from_ref(&level).cast());
3626
3627        for statement in statements {
3628            log::trace!("statement[{}] {:?}", level.0, statement);
3629            match *statement {
3630                crate::Statement::Emit(ref range) => {
3631                    for handle in range.clone() {
3632                        use crate::MathFunction as Mf;
3633
3634                        match context.expression.function.expressions[handle] {
3635                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3636                            // may need to cache a clamped version of their level-of-detail argument.
3637                            crate::Expression::ImageLoad {
3638                                image,
3639                                level: mip_level,
3640                                ..
3641                            } => {
3642                                self.put_cache_restricted_level(
3643                                    handle, image, mip_level, level, context,
3644                                )?;
3645                            }
3646
3647                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3648                            // 2.1+ then we introduce two intermediate variables that recast the two
3649                            // arguments as packed (signed or unsigned) chars. The actual dot product
3650                            // is implemented in `Self::put_expression`, and it uses both of these
3651                            // intermediate variables multiple times. There's no danger that the
3652                            // original arguments get modified between the definition of these
3653                            // intermediate variables and the implementation of the actual dot
3654                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3655                            crate::Expression::Math {
3656                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3657                                arg,
3658                                arg1,
3659                                ..
3660                            } if context.expression.lang_version >= (2, 1) => {
3661                                self.put_casting_to_packed_chars(
3662                                    fun,
3663                                    arg,
3664                                    arg1.unwrap(),
3665                                    level,
3666                                    context,
3667                                )?;
3668                            }
3669
3670                            _ => (),
3671                        }
3672
3673                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3674                        let expr_name = if ptr_class.is_some() {
3675                            None // don't bake pointer expressions (just yet)
3676                        } else if let Some(name) =
3677                            context.expression.function.named_expressions.get(&handle)
3678                        {
3679                            // The `crate::Function::named_expressions` table holds
3680                            // expressions that should be saved in temporaries once they
3681                            // are `Emit`ted. We only add them to `self.named_expressions`
3682                            // when we reach the `Emit` that covers them, so that we don't
3683                            // try to use their names before we've actually initialized
3684                            // the temporary that holds them.
3685                            //
3686                            // Don't assume the names in `named_expressions` are unique,
3687                            // or even valid. Use the `Namer`.
3688                            Some(self.namer.call(name))
3689                        } else {
3690                            // If this expression is an index that we're going to first compare
3691                            // against a limit, and then actually use as an index, then we may
3692                            // want to cache it in a temporary, to avoid evaluating it twice.
3693                            let bake = if context.expression.guarded_indices.contains(handle) {
3694                                true
3695                            } else {
3696                                self.need_bake_expressions.contains(&handle)
3697                            };
3698
3699                            if bake {
3700                                Some(Baked(handle).to_string())
3701                            } else {
3702                                None
3703                            }
3704                        };
3705
3706                        if let Some(name) = expr_name {
3707                            write!(self.out, "{level}")?;
3708                            self.start_baking_expression(handle, &context.expression, &name)?;
3709                            self.put_expression(handle, &context.expression, true)?;
3710                            self.named_expressions.insert(handle, name);
3711                            writeln!(self.out, ";")?;
3712                        }
3713                    }
3714                }
3715                crate::Statement::Block(ref block) => {
3716                    if !block.is_empty() {
3717                        writeln!(self.out, "{level}{{")?;
3718                        self.put_block(level.next(), block, context)?;
3719                        writeln!(self.out, "{level}}}")?;
3720                    }
3721                }
3722                crate::Statement::If {
3723                    condition,
3724                    ref accept,
3725                    ref reject,
3726                } => {
3727                    write!(self.out, "{level}if (")?;
3728                    self.put_expression(condition, &context.expression, true)?;
3729                    writeln!(self.out, ") {{")?;
3730                    self.put_block(level.next(), accept, context)?;
3731                    if !reject.is_empty() {
3732                        writeln!(self.out, "{level}}} else {{")?;
3733                        self.put_block(level.next(), reject, context)?;
3734                    }
3735                    writeln!(self.out, "{level}}}")?;
3736                }
3737                crate::Statement::Switch {
3738                    selector,
3739                    ref cases,
3740                } => {
3741                    write!(self.out, "{level}switch(")?;
3742                    self.put_expression(selector, &context.expression, true)?;
3743                    writeln!(self.out, ") {{")?;
3744                    let lcase = level.next();
3745                    for case in cases.iter() {
3746                        match case.value {
3747                            crate::SwitchValue::I32(value) => {
3748                                write!(self.out, "{lcase}case {value}:")?;
3749                            }
3750                            crate::SwitchValue::U32(value) => {
3751                                write!(self.out, "{lcase}case {value}u:")?;
3752                            }
3753                            crate::SwitchValue::Default => {
3754                                write!(self.out, "{lcase}default:")?;
3755                            }
3756                        }
3757
3758                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3759                        if write_block_braces {
3760                            writeln!(self.out, " {{")?;
3761                        } else {
3762                            writeln!(self.out)?;
3763                        }
3764
3765                        self.put_block(lcase.next(), &case.body, context)?;
3766                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3767                        {
3768                            writeln!(self.out, "{}break;", lcase.next())?;
3769                        }
3770
3771                        if write_block_braces {
3772                            writeln!(self.out, "{lcase}}}")?;
3773                        }
3774                    }
3775                    writeln!(self.out, "{level}}}")?;
3776                }
3777                crate::Statement::Loop {
3778                    ref body,
3779                    ref continuing,
3780                    break_if,
3781                } => {
3782                    let force_loop_bound_statements =
3783                        self.gen_force_bounded_loop_statements(level, context);
3784                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3785                        .then(|| self.namer.call("loop_init"));
3786
3787                    if let Some((ref decl, _)) = force_loop_bound_statements {
3788                        writeln!(self.out, "{decl}")?;
3789                    }
3790                    if let Some(ref gate_name) = gate_name {
3791                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3792                    }
3793
3794                    writeln!(self.out, "{level}while(true) {{",)?;
3795                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3796                        writeln!(self.out, "{break_and_inc}")?;
3797                    }
3798                    if let Some(ref gate_name) = gate_name {
3799                        let lif = level.next();
3800                        let lcontinuing = lif.next();
3801                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3802                        self.put_block(lcontinuing, continuing, context)?;
3803                        if let Some(condition) = break_if {
3804                            write!(self.out, "{lcontinuing}if (")?;
3805                            self.put_expression(condition, &context.expression, true)?;
3806                            writeln!(self.out, ") {{")?;
3807                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3808                            writeln!(self.out, "{lcontinuing}}}")?;
3809                        }
3810                        writeln!(self.out, "{lif}}}")?;
3811                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3812                    }
3813                    self.put_block(level.next(), body, context)?;
3814
3815                    writeln!(self.out, "{level}}}")?;
3816                }
3817                crate::Statement::Break => {
3818                    writeln!(self.out, "{level}break;")?;
3819                }
3820                crate::Statement::Continue => {
3821                    writeln!(self.out, "{level}continue;")?;
3822                }
3823                crate::Statement::Return {
3824                    value: Some(expr_handle),
3825                } => {
3826                    self.put_return_value(
3827                        level,
3828                        expr_handle,
3829                        context.result_struct,
3830                        &context.expression,
3831                    )?;
3832                }
3833                crate::Statement::Return { value: None } => {
3834                    writeln!(self.out, "{level}return;")?;
3835                }
3836                crate::Statement::Kill => {
3837                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3838                }
3839                crate::Statement::ControlBarrier(flags)
3840                | crate::Statement::MemoryBarrier(flags) => {
3841                    self.write_barrier(flags, level)?;
3842                }
3843                crate::Statement::Store { pointer, value } => {
3844                    self.put_store(pointer, value, level, context)?
3845                }
3846                crate::Statement::ImageStore {
3847                    image,
3848                    coordinate,
3849                    array_index,
3850                    value,
3851                } => {
3852                    let address = TexelAddress {
3853                        coordinate,
3854                        array_index,
3855                        sample: None,
3856                        level: None,
3857                    };
3858                    self.put_image_store(level, image, &address, value, context)?
3859                }
3860                crate::Statement::Call {
3861                    function,
3862                    ref arguments,
3863                    result,
3864                } => {
3865                    write!(self.out, "{level}")?;
3866                    if let Some(expr) = result {
3867                        let name = Baked(expr).to_string();
3868                        self.start_baking_expression(expr, &context.expression, &name)?;
3869                        self.named_expressions.insert(expr, name);
3870                    }
3871                    let fun_name = &self.names[&NameKey::Function(function)];
3872                    write!(self.out, "{fun_name}(")?;
3873                    // first, write down the actual arguments
3874                    for (i, &handle) in arguments.iter().enumerate() {
3875                        if i != 0 {
3876                            write!(self.out, ", ")?;
3877                        }
3878                        self.put_expression(handle, &context.expression, true)?;
3879                    }
3880                    // follow-up with any global resources used
3881                    let mut separate = !arguments.is_empty();
3882                    let fun_info = &context.expression.mod_info[function];
3883                    let mut needs_buffer_sizes = false;
3884                    for (handle, var) in context.expression.module.global_variables.iter() {
3885                        if fun_info[handle].is_empty() {
3886                            continue;
3887                        }
3888                        if var.space.needs_pass_through() {
3889                            let name = &self.names[&NameKey::GlobalVariable(handle)];
3890                            if separate {
3891                                write!(self.out, ", ")?;
3892                            } else {
3893                                separate = true;
3894                            }
3895                            write!(self.out, "{name}")?;
3896                        }
3897                        needs_buffer_sizes |=
3898                            needs_array_length(var.ty, &context.expression.module.types);
3899                    }
3900                    if needs_buffer_sizes {
3901                        if separate {
3902                            write!(self.out, ", ")?;
3903                        }
3904                        write!(self.out, "_buffer_sizes")?;
3905                    }
3906
3907                    // done
3908                    writeln!(self.out, ");")?;
3909                }
3910                crate::Statement::Atomic {
3911                    pointer,
3912                    ref fun,
3913                    value,
3914                    result,
3915                } => {
3916                    let context = &context.expression;
3917
3918                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
3919                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
3920                    // `Some`, we are not operating on a 64-bit value, and that if we are
3921                    // operating on a 64-bit value, `result` is `None`.
3922                    write!(self.out, "{level}")?;
3923                    let fun_key = if let Some(result) = result {
3924                        let res_name = Baked(result).to_string();
3925                        self.start_baking_expression(result, context, &res_name)?;
3926                        self.named_expressions.insert(result, res_name);
3927                        fun.to_msl()
3928                    } else if context.resolve_type(value).scalar_width() == Some(8) {
3929                        fun.to_msl_64_bit()?
3930                    } else {
3931                        fun.to_msl()
3932                    };
3933
3934                    // If the pointer we're passing to the atomic operation needs to be conditional
3935                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
3936                    // the pointer operand should be unchecked.
3937                    let policy = context.choose_bounds_check_policy(pointer);
3938                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3939                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3940
3941                    // If requested and successfully put bounds checks, continue the ternary expression.
3942                    if checked {
3943                        write!(self.out, " ? ")?;
3944                    }
3945
3946                    // Put the atomic function invocation.
3947                    match *fun {
3948                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3949                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3950                            self.put_access_chain(pointer, policy, context)?;
3951                            write!(self.out, ", ")?;
3952                            self.put_expression(cmp, context, true)?;
3953                            write!(self.out, ", ")?;
3954                            self.put_expression(value, context, true)?;
3955                            write!(self.out, ")")?;
3956                        }
3957                        _ => {
3958                            write!(
3959                                self.out,
3960                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3961                            )?;
3962                            self.put_access_chain(pointer, policy, context)?;
3963                            write!(self.out, ", ")?;
3964                            self.put_expression(value, context, true)?;
3965                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3966                        }
3967                    }
3968
3969                    // Finish the ternary expression.
3970                    if checked {
3971                        write!(self.out, " : DefaultConstructible()")?;
3972                    }
3973
3974                    // Done
3975                    writeln!(self.out, ";")?;
3976                }
3977                crate::Statement::ImageAtomic {
3978                    image,
3979                    coordinate,
3980                    array_index,
3981                    fun,
3982                    value,
3983                } => {
3984                    let address = TexelAddress {
3985                        coordinate,
3986                        array_index,
3987                        sample: None,
3988                        level: None,
3989                    };
3990                    self.put_image_atomic(level, image, &address, fun, value, context)?
3991                }
3992                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3993                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3994
3995                    write!(self.out, "{level}")?;
3996                    let name = self.namer.call("");
3997                    self.start_baking_expression(result, &context.expression, &name)?;
3998                    self.put_load(pointer, &context.expression, true)?;
3999                    self.named_expressions.insert(result, name);
4000
4001                    writeln!(self.out, ";")?;
4002                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4003                }
4004                crate::Statement::RayQuery { query, ref fun } => {
4005                    if context.expression.lang_version < (2, 4) {
4006                        return Err(Error::UnsupportedRayTracing);
4007                    }
4008
4009                    match *fun {
4010                        crate::RayQueryFunction::Initialize {
4011                            acceleration_structure,
4012                            descriptor,
4013                        } => {
4014                            //TODO: how to deal with winding?
4015                            write!(self.out, "{level}")?;
4016                            self.put_expression(query, &context.expression, true)?;
4017                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
4018                            {
4019                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
4020                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
4021                                write!(self.out, "{level}")?;
4022                                self.put_expression(query, &context.expression, true)?;
4023                                write!(
4024                                    self.out,
4025                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
4026                                )?;
4027                                self.put_expression(descriptor, &context.expression, true)?;
4028                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
4029                                self.put_expression(descriptor, &context.expression, true)?;
4030                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
4031                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
4032                            }
4033                            {
4034                                let f_opaque = back::RayFlag::OPAQUE.bits();
4035                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
4036                                write!(self.out, "{level}")?;
4037                                self.put_expression(query, &context.expression, true)?;
4038                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
4039                                self.put_expression(descriptor, &context.expression, true)?;
4040                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
4041                                self.put_expression(descriptor, &context.expression, true)?;
4042                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
4043                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
4044                            }
4045                            {
4046                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
4047                                write!(self.out, "{level}")?;
4048                                self.put_expression(query, &context.expression, true)?;
4049                                write!(
4050                                    self.out,
4051                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
4052                                )?;
4053                                self.put_expression(descriptor, &context.expression, true)?;
4054                                writeln!(self.out, ".flags & {flag}) != 0);")?;
4055                            }
4056
4057                            write!(self.out, "{level}")?;
4058                            self.put_expression(query, &context.expression, true)?;
4059                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
4060                            self.put_expression(query, &context.expression, true)?;
4061                            write!(
4062                                self.out,
4063                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4064                            )?;
4065                            self.put_expression(descriptor, &context.expression, true)?;
4066                            write!(self.out, ".origin, ")?;
4067                            self.put_expression(descriptor, &context.expression, true)?;
4068                            write!(self.out, ".dir, ")?;
4069                            self.put_expression(descriptor, &context.expression, true)?;
4070                            write!(self.out, ".tmin, ")?;
4071                            self.put_expression(descriptor, &context.expression, true)?;
4072                            write!(self.out, ".tmax), ")?;
4073                            self.put_expression(acceleration_structure, &context.expression, true)?;
4074                            write!(self.out, ", ")?;
4075                            self.put_expression(descriptor, &context.expression, true)?;
4076                            write!(self.out, ".cull_mask);")?;
4077
4078                            write!(self.out, "{level}")?;
4079                            self.put_expression(query, &context.expression, true)?;
4080                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4081                        }
4082                        crate::RayQueryFunction::Proceed { result } => {
4083                            write!(self.out, "{level}")?;
4084                            let name = Baked(result).to_string();
4085                            self.start_baking_expression(result, &context.expression, &name)?;
4086                            self.named_expressions.insert(result, name);
4087                            self.put_expression(query, &context.expression, true)?;
4088                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4089                            if RAY_QUERY_MODERN_SUPPORT {
4090                                write!(self.out, "{level}")?;
4091                                self.put_expression(query, &context.expression, true)?;
4092                                writeln!(self.out, ".?.next();")?;
4093                            }
4094                        }
4095                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4096                            if RAY_QUERY_MODERN_SUPPORT {
4097                                write!(self.out, "{level}")?;
4098                                self.put_expression(query, &context.expression, true)?;
4099                                write!(self.out, ".?.commit_bounding_box_intersection(")?;
4100                                self.put_expression(hit_t, &context.expression, true)?;
4101                                writeln!(self.out, ");")?;
4102                            } else {
4103                                log::warn!("Ray Query GenerateIntersection is not yet supported");
4104                            }
4105                        }
4106                        crate::RayQueryFunction::ConfirmIntersection => {
4107                            if RAY_QUERY_MODERN_SUPPORT {
4108                                write!(self.out, "{level}")?;
4109                                self.put_expression(query, &context.expression, true)?;
4110                                writeln!(self.out, ".?.commit_triangle_intersection();")?;
4111                            } else {
4112                                log::warn!("Ray Query ConfirmIntersection is not yet supported");
4113                            }
4114                        }
4115                        crate::RayQueryFunction::Terminate => {
4116                            if RAY_QUERY_MODERN_SUPPORT {
4117                                write!(self.out, "{level}")?;
4118                                self.put_expression(query, &context.expression, true)?;
4119                                writeln!(self.out, ".?.abort();")?;
4120                            }
4121                            write!(self.out, "{level}")?;
4122                            self.put_expression(query, &context.expression, true)?;
4123                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4124                        }
4125                    }
4126                }
4127                crate::Statement::SubgroupBallot { result, predicate } => {
4128                    write!(self.out, "{level}")?;
4129                    let name = self.namer.call("");
4130                    self.start_baking_expression(result, &context.expression, &name)?;
4131                    self.named_expressions.insert(result, name);
4132                    write!(
4133                        self.out,
4134                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4135                    )?;
4136                    if let Some(predicate) = predicate {
4137                        self.put_expression(predicate, &context.expression, true)?;
4138                    } else {
4139                        write!(self.out, "true")?;
4140                    }
4141                    writeln!(self.out, "), 0, 0, 0);")?;
4142                }
4143                crate::Statement::SubgroupCollectiveOperation {
4144                    op,
4145                    collective_op,
4146                    argument,
4147                    result,
4148                } => {
4149                    write!(self.out, "{level}")?;
4150                    let name = self.namer.call("");
4151                    self.start_baking_expression(result, &context.expression, &name)?;
4152                    self.named_expressions.insert(result, name);
4153                    match (collective_op, op) {
4154                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4155                            write!(self.out, "{NAMESPACE}::simd_all(")?
4156                        }
4157                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4158                            write!(self.out, "{NAMESPACE}::simd_any(")?
4159                        }
4160                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4161                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4162                        }
4163                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4164                            write!(self.out, "{NAMESPACE}::simd_product(")?
4165                        }
4166                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4167                            write!(self.out, "{NAMESPACE}::simd_max(")?
4168                        }
4169                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4170                            write!(self.out, "{NAMESPACE}::simd_min(")?
4171                        }
4172                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4173                            write!(self.out, "{NAMESPACE}::simd_and(")?
4174                        }
4175                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4176                            write!(self.out, "{NAMESPACE}::simd_or(")?
4177                        }
4178                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4179                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4180                        }
4181                        (
4182                            crate::CollectiveOperation::ExclusiveScan,
4183                            crate::SubgroupOperation::Add,
4184                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4185                        (
4186                            crate::CollectiveOperation::ExclusiveScan,
4187                            crate::SubgroupOperation::Mul,
4188                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4189                        (
4190                            crate::CollectiveOperation::InclusiveScan,
4191                            crate::SubgroupOperation::Add,
4192                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4193                        (
4194                            crate::CollectiveOperation::InclusiveScan,
4195                            crate::SubgroupOperation::Mul,
4196                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4197                        _ => unimplemented!(),
4198                    }
4199                    self.put_expression(argument, &context.expression, true)?;
4200                    writeln!(self.out, ");")?;
4201                }
4202                crate::Statement::SubgroupGather {
4203                    mode,
4204                    argument,
4205                    result,
4206                } => {
4207                    write!(self.out, "{level}")?;
4208                    let name = self.namer.call("");
4209                    self.start_baking_expression(result, &context.expression, &name)?;
4210                    self.named_expressions.insert(result, name);
4211                    match mode {
4212                        crate::GatherMode::BroadcastFirst => {
4213                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4214                        }
4215                        crate::GatherMode::Broadcast(_) => {
4216                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4217                        }
4218                        crate::GatherMode::Shuffle(_) => {
4219                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4220                        }
4221                        crate::GatherMode::ShuffleDown(_) => {
4222                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4223                        }
4224                        crate::GatherMode::ShuffleUp(_) => {
4225                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4226                        }
4227                        crate::GatherMode::ShuffleXor(_) => {
4228                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4229                        }
4230                        crate::GatherMode::QuadBroadcast(_) => {
4231                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4232                        }
4233                        crate::GatherMode::QuadSwap(_) => {
4234                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4235                        }
4236                    }
4237                    self.put_expression(argument, &context.expression, true)?;
4238                    match mode {
4239                        crate::GatherMode::BroadcastFirst => {}
4240                        crate::GatherMode::Broadcast(index)
4241                        | crate::GatherMode::Shuffle(index)
4242                        | crate::GatherMode::ShuffleDown(index)
4243                        | crate::GatherMode::ShuffleUp(index)
4244                        | crate::GatherMode::ShuffleXor(index)
4245                        | crate::GatherMode::QuadBroadcast(index) => {
4246                            write!(self.out, ", ")?;
4247                            self.put_expression(index, &context.expression, true)?;
4248                        }
4249                        crate::GatherMode::QuadSwap(direction) => {
4250                            write!(self.out, ", ")?;
4251                            match direction {
4252                                crate::Direction::X => {
4253                                    write!(self.out, "1u")?;
4254                                }
4255                                crate::Direction::Y => {
4256                                    write!(self.out, "2u")?;
4257                                }
4258                                crate::Direction::Diagonal => {
4259                                    write!(self.out, "3u")?;
4260                                }
4261                            }
4262                        }
4263                    }
4264                    writeln!(self.out, ");")?;
4265                }
4266                crate::Statement::CooperativeStore { target, ref data } => {
4267                    write!(self.out, "{level}simdgroup_store(")?;
4268                    self.put_expression(target, &context.expression, true)?;
4269                    write!(self.out, ", &")?;
4270                    self.put_access_chain(
4271                        data.pointer,
4272                        context.expression.policies.index,
4273                        &context.expression,
4274                    )?;
4275                    write!(self.out, ", ")?;
4276                    self.put_expression(data.stride, &context.expression, true)?;
4277                    if data.row_major {
4278                        let matrix_origin = "0";
4279                        let transpose = true;
4280                        write!(self.out, ", {matrix_origin}, {transpose}")?;
4281                    }
4282                    writeln!(self.out, ");")?;
4283                }
4284            }
4285        }
4286
4287        // un-emit expressions
4288        //TODO: take care of loop/continuing?
4289        for statement in statements {
4290            if let crate::Statement::Emit(ref range) = *statement {
4291                for handle in range.clone() {
4292                    self.named_expressions.shift_remove(&handle);
4293                }
4294            }
4295        }
4296        Ok(())
4297    }
4298
4299    fn put_store(
4300        &mut self,
4301        pointer: Handle<crate::Expression>,
4302        value: Handle<crate::Expression>,
4303        level: back::Level,
4304        context: &StatementContext,
4305    ) -> BackendResult {
4306        let policy = context.expression.choose_bounds_check_policy(pointer);
4307        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4308            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4309        {
4310            writeln!(self.out, ") {{")?;
4311            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4312            writeln!(self.out, "{level}}}")?;
4313        } else {
4314            self.put_unchecked_store(pointer, value, policy, level, context)?;
4315        }
4316
4317        Ok(())
4318    }
4319
4320    fn put_unchecked_store(
4321        &mut self,
4322        pointer: Handle<crate::Expression>,
4323        value: Handle<crate::Expression>,
4324        policy: index::BoundsCheckPolicy,
4325        level: back::Level,
4326        context: &StatementContext,
4327    ) -> BackendResult {
4328        let is_atomic_pointer = context
4329            .expression
4330            .resolve_type(pointer)
4331            .is_atomic_pointer(&context.expression.module.types);
4332
4333        if is_atomic_pointer {
4334            write!(
4335                self.out,
4336                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4337            )?;
4338            self.put_access_chain(pointer, policy, &context.expression)?;
4339            write!(self.out, ", ")?;
4340            self.put_expression(value, &context.expression, true)?;
4341            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4342        } else {
4343            write!(self.out, "{level}")?;
4344            self.put_access_chain(pointer, policy, &context.expression)?;
4345            write!(self.out, " = ")?;
4346            self.put_expression(value, &context.expression, true)?;
4347            writeln!(self.out, ";")?;
4348        }
4349
4350        Ok(())
4351    }
4352
4353    pub fn write(
4354        &mut self,
4355        module: &crate::Module,
4356        info: &valid::ModuleInfo,
4357        options: &Options,
4358        pipeline_options: &PipelineOptions,
4359    ) -> Result<TranslationInfo, Error> {
4360        self.names.clear();
4361        self.namer.reset(
4362            module,
4363            &super::keywords::RESERVED_SET,
4364            proc::CaseInsensitiveKeywordSet::empty(),
4365            &[CLAMPED_LOD_LOAD_PREFIX],
4366            &mut self.names,
4367        );
4368        self.wrapped_functions.clear();
4369        self.struct_member_pads.clear();
4370
4371        writeln!(
4372            self.out,
4373            "// language: metal{}.{}",
4374            options.lang_version.0, options.lang_version.1
4375        )?;
4376        writeln!(self.out, "#include <metal_stdlib>")?;
4377        writeln!(self.out, "#include <simd/simd.h>")?;
4378        writeln!(self.out)?;
4379        // Work around Metal bug where `uint` is not available by default
4380        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4381
4382        let mut uses_ray_query = false;
4383        for (_, ty) in module.types.iter() {
4384            match ty.inner {
4385                crate::TypeInner::AccelerationStructure { .. } => {
4386                    if options.lang_version < (2, 4) {
4387                        return Err(Error::UnsupportedRayTracing);
4388                    }
4389                }
4390                crate::TypeInner::RayQuery { .. } => {
4391                    if options.lang_version < (2, 4) {
4392                        return Err(Error::UnsupportedRayTracing);
4393                    }
4394                    uses_ray_query = true;
4395                }
4396                _ => (),
4397            }
4398        }
4399
4400        if module.special_types.ray_desc.is_some()
4401            || module.special_types.ray_intersection.is_some()
4402        {
4403            if options.lang_version < (2, 4) {
4404                return Err(Error::UnsupportedRayTracing);
4405            }
4406        }
4407
4408        if uses_ray_query {
4409            self.put_ray_query_type()?;
4410        }
4411
4412        if options
4413            .bounds_check_policies
4414            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4415        {
4416            self.put_default_constructible()?;
4417        }
4418        writeln!(self.out)?;
4419
4420        {
4421            // Make a `Vec` of all the `GlobalVariable`s that contain
4422            // runtime-sized arrays.
4423            let globals: Vec<Handle<crate::GlobalVariable>> = module
4424                .global_variables
4425                .iter()
4426                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4427                .map(|(handle, _)| handle)
4428                .collect();
4429
4430            let mut buffer_indices = vec![];
4431            for vbm in &pipeline_options.vertex_buffer_mappings {
4432                buffer_indices.push(vbm.id);
4433            }
4434
4435            if !globals.is_empty() || !buffer_indices.is_empty() {
4436                writeln!(self.out, "struct _mslBufferSizes {{")?;
4437
4438                for global in globals {
4439                    writeln!(
4440                        self.out,
4441                        "{}uint {};",
4442                        back::INDENT,
4443                        ArraySizeMember(global)
4444                    )?;
4445                }
4446
4447                for idx in buffer_indices {
4448                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4449                }
4450
4451                writeln!(self.out, "}};")?;
4452                writeln!(self.out)?;
4453            }
4454        };
4455
4456        self.write_type_defs(module)?;
4457        self.write_global_constants(module, info)?;
4458        self.write_functions(module, info, options, pipeline_options)
4459    }
4460
4461    /// Write the definition for the `DefaultConstructible` class.
4462    ///
4463    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4464    /// produce 'zero' values for any type, including structs, arrays, and so
4465    /// on. We could do this by emitting default constructor applications, but
4466    /// that would entail printing the name of the type, which is more trouble
4467    /// than you'd think. Instead, we just construct this magic C++14 class that
4468    /// can be converted to any type that can be default constructed, using
4469    /// template parameter inference to detect which type is needed, so we don't
4470    /// have to figure out the name.
4471    ///
4472    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4473    fn put_default_constructible(&mut self) -> BackendResult {
4474        let tab = back::INDENT;
4475        writeln!(self.out, "struct DefaultConstructible {{")?;
4476        writeln!(self.out, "{tab}template<typename T>")?;
4477        writeln!(self.out, "{tab}operator T() && {{")?;
4478        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4479        writeln!(self.out, "{tab}}}")?;
4480        writeln!(self.out, "}};")?;
4481        Ok(())
4482    }
4483
4484    fn put_ray_query_type(&mut self) -> BackendResult {
4485        let tab = back::INDENT;
4486        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4487        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4488        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4489        writeln!(
4490            self.out,
4491            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4492        )?;
4493        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4494        writeln!(self.out, "}};")?;
4495        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4496        let v_triangle = back::RayIntersectionType::Triangle as u32;
4497        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4498        writeln!(
4499            self.out,
4500            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4501        )?;
4502        writeln!(
4503            self.out,
4504            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4505        )?;
4506        writeln!(self.out, "}}")?;
4507        Ok(())
4508    }
4509
    /// Write MSL definitions for the module's types, followed by helper
    /// functions for the module's predeclared special types.
    ///
    /// For each type in `module.types`, in arena order:
    ///
    /// - the first `BindingArray` type triggers a one-time templated wrapper
    ///   struct, and the first external image type triggers a one-time
    ///   external texture wrapper struct;
    /// - types for which `needs_alias` returns `false` get nothing further;
    /// - arrays become a wrapper struct (known length) or a typedef (dynamic
    ///   length), structs become MSL structs with explicit padding members,
    ///   and everything else becomes a `typedef`.
    ///
    /// Afterwards, one helper function is written per `ModfResult` /
    /// `FrexpResult` predeclared type, and a pair of helper templates per
    /// `AtomicCompareExchangeWeakResult` predeclared type.
    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
        // Each of the two wrapper structs below is written at most once.
        let mut generated_argument_buffer_wrapper = false;
        let mut generated_external_texture_wrapper = false;
        for (handle, ty) in module.types.iter() {
            match ty.inner {
                // Templated struct with a single `WRAPPED_ARRAY_FIELD` member
                // of type `T`.
                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
                    writeln!(self.out, "template <typename T>")?;
                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_argument_buffer_wrapper = true;
                }
                // Struct with three `texture2d` planes and a `params` member.
                crate::TypeInner::Image {
                    class: crate::ImageClass::External,
                    ..
                } if !generated_external_texture_wrapper => {
                    // NOTE(review): the `unwrap` assumes the
                    // `external_texture_params` special type is always present
                    // when the module has an external image type - confirm.
                    let params_ty_name = &self.names
                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
                        back::INDENT
                    )?;
                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_external_texture_wrapper = true;
                }
                _ => {}
            }

            // Only types that need an alias get a definition of their own.
            if !ty.needs_alias() {
                continue;
            }
            let name = &self.names[&NameKey::Type(handle)];
            match ty.inner {
                // Naga IR can pass around arrays by value, but Metal, following
                // C++, performs an array-to-pointer conversion (C++ [conv.array])
                // on expressions of array type, so assigning the array by value
                // isn't possible. However, Metal *does* assign structs by
                // value. So in our Metal output, we wrap all array types in
                // synthetic struct types:
                //
                //     struct type1 {
                //         float inner[10];
                //     };
                //
                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
                // any expression that actually wants access to the array.
                crate::TypeInner::Array {
                    base,
                    size,
                    stride: _,
                } => {
                    let base_name = TypeContext {
                        handle: base,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: false,
                    };

                    match size.resolve(module.to_ctx())? {
                        proc::IndexableLength::Known(size) => {
                            writeln!(self.out, "struct {name} {{")?;
                            writeln!(
                                self.out,
                                "{}{} {}[{}];",
                                back::INDENT,
                                base_name,
                                WRAPPED_ARRAY_FIELD,
                                size
                            )?;
                            writeln!(self.out, "}};")?;
                        }
                        // Dynamically sized arrays are not wrapped in a
                        // struct; they are written as a typedef of a
                        // one-element array.
                        proc::IndexableLength::Dynamic => {
                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members, span, ..
                } => {
                    writeln!(self.out, "struct {name} {{")?;
                    // Offset just past the previously written member, used to
                    // detect gaps that need explicit padding.
                    let mut last_offset = 0;
                    for (index, member) in members.iter().enumerate() {
                        // A gap before this member: write a `char` array of
                        // the missing size and record that this member is
                        // preceded by padding.
                        if member.offset > last_offset {
                            self.struct_member_pads.insert((handle, index as u32));
                            let pad = member.offset - last_offset;
                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
                        }
                        let ty_inner = &module.types[member.ty].inner;
                        last_offset = member.offset + ty_inner.size(module.to_ctx());

                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];

                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
                        match should_pack_struct_member(members, span, index, module) {
                            Some(scalar) => {
                                writeln!(
                                    self.out,
                                    "{}{}::packed_{}3 {};",
                                    back::INDENT,
                                    NAMESPACE,
                                    scalar.to_msl_name(),
                                    member_name
                                )?;
                            }
                            None => {
                                let base_name = TypeContext {
                                    handle: member.ty,
                                    gctx: module.to_ctx(),
                                    names: &self.names,
                                    access: crate::StorageAccess::empty(),
                                    first_time: false,
                                };
                                writeln!(
                                    self.out,
                                    "{}{} {};",
                                    back::INDENT,
                                    base_name,
                                    member_name
                                )?;

                                // for 3-component vectors, add one component
                                if let crate::TypeInner::Vector {
                                    size: crate::VectorSize::Tri,
                                    scalar,
                                } = *ty_inner
                                {
                                    last_offset += scalar.width as u32;
                                }
                            }
                        }
                    }
                    // Trailing padding, so the members add up to the struct's
                    // `span`.
                    if last_offset < span {
                        let pad = span - last_offset;
                        writeln!(
                            self.out,
                            "{}char _pad{}[{}];",
                            back::INDENT,
                            members.len(),
                            pad
                        )?;
                    }
                    writeln!(self.out, "}};")?;
                }
                // Every other aliased type is a plain `typedef`.
                _ => {
                    let ty_name = TypeContext {
                        handle,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: true,
                    };
                    writeln!(self.out, "typedef {ty_name} {name};")?;
                }
            }
        }

        // Write functions to create special types.
        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, scalar }
                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                    // Name of the argument type: `double` when the scalar is
                    // 8 bytes wide and `float` otherwise, as a
                    // `metal::`-qualified vector when `size` is `Some`.
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{NAMESPACE}::{}{}",
                            if scalar.width == 8 { "double" } else { "float" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if scalar.width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    // The out-parameter of `modf` has the argument's type;
                    // that of `frexp` is an `int` or `int<N>`.
                    let other_type_name_owner;
                    let (defined_func_name, called_func_name, other_type_name) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (MODF_FUNCTION, "modf", arg_type_name)
                        } else {
                            let other_type_name = if let Some(size) = size {
                                other_type_name_owner = format!("int{}", size as u8);
                                &other_type_name_owner
                            } else {
                                "int"
                            };
                            (FREXP_FUNCTION, "frexp", other_type_name)
                        };

                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    // The helper calls the Metal function with a local as the
                    // out-parameter and returns both values in the result
                    // struct.
                    writeln!(self.out)?;
                    writeln!(
                        self.out,
                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
    return {struct_name}{{ fract, other }};
}}"
                    )?;
                }
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
                    let arg_type_name = scalar.to_msl_name();
                    let called_func_name = "atomic_compare_exchange_weak_explicit";
                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;

                    // One overload per address space; each returns `cmp`
                    // (after the call) together with the call's `bool` result.
                    for address_space_name in ["device", "threadgroup"] {
                        writeln!(
                            self.out,
                            "\
template <typename A>
{struct_name} {defined_func_name}(
    {address_space_name} A *atomic_ptr,
    {arg_type_name} cmp,
    {arg_type_name} v
) {{
    bool swapped = {NAMESPACE}::{called_func_name}(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return {struct_name}{{cmp, swapped}};
}}"
                        )?;
                    }
                }
            }
        }

        Ok(())
    }
4757
4758    /// Writes all named constants
4759    fn write_global_constants(
4760        &mut self,
4761        module: &crate::Module,
4762        mod_info: &valid::ModuleInfo,
4763    ) -> BackendResult {
4764        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4765
4766        for (handle, constant) in constants {
4767            let ty_name = TypeContext {
4768                handle: constant.ty,
4769                gctx: module.to_ctx(),
4770                names: &self.names,
4771                access: crate::StorageAccess::empty(),
4772                first_time: false,
4773            };
4774            let name = &self.names[&NameKey::Constant(handle)];
4775            write!(self.out, "constant {ty_name} {name} = ")?;
4776            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4777            writeln!(self.out, ";")?;
4778        }
4779
4780        Ok(())
4781    }
4782
4783    fn put_inline_sampler_properties(
4784        &mut self,
4785        level: back::Level,
4786        sampler: &sm::InlineSampler,
4787    ) -> BackendResult {
4788        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4789            writeln!(
4790                self.out,
4791                "{}{}::{}_address::{},",
4792                level,
4793                NAMESPACE,
4794                letter,
4795                address.as_str(),
4796            )?;
4797        }
4798        writeln!(
4799            self.out,
4800            "{}{}::mag_filter::{},",
4801            level,
4802            NAMESPACE,
4803            sampler.mag_filter.as_str(),
4804        )?;
4805        writeln!(
4806            self.out,
4807            "{}{}::min_filter::{},",
4808            level,
4809            NAMESPACE,
4810            sampler.min_filter.as_str(),
4811        )?;
4812        if let Some(filter) = sampler.mip_filter {
4813            writeln!(
4814                self.out,
4815                "{}{}::mip_filter::{},",
4816                level,
4817                NAMESPACE,
4818                filter.as_str(),
4819            )?;
4820        }
4821        // avoid setting it on platforms that don't support it
4822        if sampler.border_color != sm::BorderColor::TransparentBlack {
4823            writeln!(
4824                self.out,
4825                "{}{}::border_color::{},",
4826                level,
4827                NAMESPACE,
4828                sampler.border_color.as_str(),
4829            )?;
4830        }
4831        //TODO: I'm not able to feed this in a way that MSL likes:
4832        //>error: use of undeclared identifier 'lod_clamp'
4833        //>error: no member named 'max_anisotropy' in namespace 'metal'
4834        if false {
4835            if let Some(ref lod) = sampler.lod_clamp {
4836                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4837            }
4838            if let Some(aniso) = sampler.max_anisotropy {
4839                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4840            }
4841        }
4842        if sampler.compare_func != sm::CompareFunc::Never {
4843            writeln!(
4844                self.out,
4845                "{}{}::compare_func::{},",
4846                level,
4847                NAMESPACE,
4848                sampler.compare_func.as_str(),
4849            )?;
4850        }
4851        writeln!(
4852            self.out,
4853            "{}{}::coord::{}",
4854            level,
4855            NAMESPACE,
4856            sampler.coord.as_str()
4857        )?;
4858        Ok(())
4859    }
4860
4861    fn write_unpacking_function(
4862        &mut self,
4863        format: back::msl::VertexFormat,
4864    ) -> Result<(String, u32, u32), Error> {
4865        use back::msl::VertexFormat::*;
4866        match format {
4867            Uint8 => {
4868                let name = self.namer.call("unpackUint8");
4869                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4870                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4871                writeln!(self.out, "}}")?;
4872                Ok((name, 1, 1))
4873            }
4874            Uint8x2 => {
4875                let name = self.namer.call("unpackUint8x2");
4876                writeln!(
4877                    self.out,
4878                    "metal::uint2 {name}(metal::uchar b0, \
4879                                         metal::uchar b1) {{"
4880                )?;
4881                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4882                writeln!(self.out, "}}")?;
4883                Ok((name, 2, 2))
4884            }
4885            Uint8x4 => {
4886                let name = self.namer.call("unpackUint8x4");
4887                writeln!(
4888                    self.out,
4889                    "metal::uint4 {name}(metal::uchar b0, \
4890                                         metal::uchar b1, \
4891                                         metal::uchar b2, \
4892                                         metal::uchar b3) {{"
4893                )?;
4894                writeln!(
4895                    self.out,
4896                    "{}return metal::uint4(b0, b1, b2, b3);",
4897                    back::INDENT
4898                )?;
4899                writeln!(self.out, "}}")?;
4900                Ok((name, 4, 4))
4901            }
4902            Sint8 => {
4903                let name = self.namer.call("unpackSint8");
4904                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4905                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4906                writeln!(self.out, "}}")?;
4907                Ok((name, 1, 1))
4908            }
4909            Sint8x2 => {
4910                let name = self.namer.call("unpackSint8x2");
4911                writeln!(
4912                    self.out,
4913                    "metal::int2 {name}(metal::uchar b0, \
4914                                        metal::uchar b1) {{"
4915                )?;
4916                writeln!(
4917                    self.out,
4918                    "{}return metal::int2(as_type<char>(b0), \
4919                                          as_type<char>(b1));",
4920                    back::INDENT
4921                )?;
4922                writeln!(self.out, "}}")?;
4923                Ok((name, 2, 2))
4924            }
4925            Sint8x4 => {
4926                let name = self.namer.call("unpackSint8x4");
4927                writeln!(
4928                    self.out,
4929                    "metal::int4 {name}(metal::uchar b0, \
4930                                        metal::uchar b1, \
4931                                        metal::uchar b2, \
4932                                        metal::uchar b3) {{"
4933                )?;
4934                writeln!(
4935                    self.out,
4936                    "{}return metal::int4(as_type<char>(b0), \
4937                                          as_type<char>(b1), \
4938                                          as_type<char>(b2), \
4939                                          as_type<char>(b3));",
4940                    back::INDENT
4941                )?;
4942                writeln!(self.out, "}}")?;
4943                Ok((name, 4, 4))
4944            }
4945            Unorm8 => {
4946                let name = self.namer.call("unpackUnorm8");
4947                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4948                writeln!(
4949                    self.out,
4950                    "{}return float(float(b0) / 255.0f);",
4951                    back::INDENT
4952                )?;
4953                writeln!(self.out, "}}")?;
4954                Ok((name, 1, 1))
4955            }
4956            Unorm8x2 => {
4957                let name = self.namer.call("unpackUnorm8x2");
4958                writeln!(
4959                    self.out,
4960                    "metal::float2 {name}(metal::uchar b0, \
4961                                          metal::uchar b1) {{"
4962                )?;
4963                writeln!(
4964                    self.out,
4965                    "{}return metal::float2(float(b0) / 255.0f, \
4966                                            float(b1) / 255.0f);",
4967                    back::INDENT
4968                )?;
4969                writeln!(self.out, "}}")?;
4970                Ok((name, 2, 2))
4971            }
4972            Unorm8x4 => {
4973                let name = self.namer.call("unpackUnorm8x4");
4974                writeln!(
4975                    self.out,
4976                    "metal::float4 {name}(metal::uchar b0, \
4977                                          metal::uchar b1, \
4978                                          metal::uchar b2, \
4979                                          metal::uchar b3) {{"
4980                )?;
4981                writeln!(
4982                    self.out,
4983                    "{}return metal::float4(float(b0) / 255.0f, \
4984                                            float(b1) / 255.0f, \
4985                                            float(b2) / 255.0f, \
4986                                            float(b3) / 255.0f);",
4987                    back::INDENT
4988                )?;
4989                writeln!(self.out, "}}")?;
4990                Ok((name, 4, 4))
4991            }
4992            Snorm8 => {
4993                let name = self.namer.call("unpackSnorm8");
4994                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4995                writeln!(
4996                    self.out,
4997                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4998                    back::INDENT
4999                )?;
5000                writeln!(self.out, "}}")?;
5001                Ok((name, 1, 1))
5002            }
5003            Snorm8x2 => {
5004                let name = self.namer.call("unpackSnorm8x2");
5005                writeln!(
5006                    self.out,
5007                    "metal::float2 {name}(metal::uchar b0, \
5008                                          metal::uchar b1) {{"
5009                )?;
5010                writeln!(
5011                    self.out,
5012                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5013                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5014                    back::INDENT
5015                )?;
5016                writeln!(self.out, "}}")?;
5017                Ok((name, 2, 2))
5018            }
5019            Snorm8x4 => {
5020                let name = self.namer.call("unpackSnorm8x4");
5021                writeln!(
5022                    self.out,
5023                    "metal::float4 {name}(metal::uchar b0, \
5024                                          metal::uchar b1, \
5025                                          metal::uchar b2, \
5026                                          metal::uchar b3) {{"
5027                )?;
5028                writeln!(
5029                    self.out,
5030                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5031                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5032                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5033                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5034                    back::INDENT
5035                )?;
5036                writeln!(self.out, "}}")?;
5037                Ok((name, 4, 4))
5038            }
5039            Uint16 => {
5040                let name = self.namer.call("unpackUint16");
5041                writeln!(
5042                    self.out,
5043                    "metal::uint {name}(metal::uint b0, \
5044                                        metal::uint b1) {{"
5045                )?;
5046                writeln!(
5047                    self.out,
5048                    "{}return metal::uint(b1 << 8 | b0);",
5049                    back::INDENT
5050                )?;
5051                writeln!(self.out, "}}")?;
5052                Ok((name, 2, 1))
5053            }
5054            Uint16x2 => {
5055                let name = self.namer.call("unpackUint16x2");
5056                writeln!(
5057                    self.out,
5058                    "metal::uint2 {name}(metal::uint b0, \
5059                                         metal::uint b1, \
5060                                         metal::uint b2, \
5061                                         metal::uint b3) {{"
5062                )?;
5063                writeln!(
5064                    self.out,
5065                    "{}return metal::uint2(b1 << 8 | b0, \
5066                                           b3 << 8 | b2);",
5067                    back::INDENT
5068                )?;
5069                writeln!(self.out, "}}")?;
5070                Ok((name, 4, 2))
5071            }
5072            Uint16x4 => {
5073                let name = self.namer.call("unpackUint16x4");
5074                writeln!(
5075                    self.out,
5076                    "metal::uint4 {name}(metal::uint b0, \
5077                                         metal::uint b1, \
5078                                         metal::uint b2, \
5079                                         metal::uint b3, \
5080                                         metal::uint b4, \
5081                                         metal::uint b5, \
5082                                         metal::uint b6, \
5083                                         metal::uint b7) {{"
5084                )?;
5085                writeln!(
5086                    self.out,
5087                    "{}return metal::uint4(b1 << 8 | b0, \
5088                                           b3 << 8 | b2, \
5089                                           b5 << 8 | b4, \
5090                                           b7 << 8 | b6);",
5091                    back::INDENT
5092                )?;
5093                writeln!(self.out, "}}")?;
5094                Ok((name, 8, 4))
5095            }
5096            Sint16 => {
5097                let name = self.namer.call("unpackSint16");
5098                writeln!(
5099                    self.out,
5100                    "int {name}(metal::ushort b0, \
5101                                metal::ushort b1) {{"
5102                )?;
5103                writeln!(
5104                    self.out,
5105                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5106                    back::INDENT
5107                )?;
5108                writeln!(self.out, "}}")?;
5109                Ok((name, 2, 1))
5110            }
5111            Sint16x2 => {
5112                let name = self.namer.call("unpackSint16x2");
5113                writeln!(
5114                    self.out,
5115                    "metal::int2 {name}(metal::ushort b0, \
5116                                        metal::ushort b1, \
5117                                        metal::ushort b2, \
5118                                        metal::ushort b3) {{"
5119                )?;
5120                writeln!(
5121                    self.out,
5122                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5123                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5124                    back::INDENT
5125                )?;
5126                writeln!(self.out, "}}")?;
5127                Ok((name, 4, 2))
5128            }
5129            Sint16x4 => {
5130                let name = self.namer.call("unpackSint16x4");
5131                writeln!(
5132                    self.out,
5133                    "metal::int4 {name}(metal::ushort b0, \
5134                                        metal::ushort b1, \
5135                                        metal::ushort b2, \
5136                                        metal::ushort b3, \
5137                                        metal::ushort b4, \
5138                                        metal::ushort b5, \
5139                                        metal::ushort b6, \
5140                                        metal::ushort b7) {{"
5141                )?;
5142                writeln!(
5143                    self.out,
5144                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5145                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5146                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5147                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5148                    back::INDENT
5149                )?;
5150                writeln!(self.out, "}}")?;
5151                Ok((name, 8, 4))
5152            }
5153            Unorm16 => {
5154                let name = self.namer.call("unpackUnorm16");
5155                writeln!(
5156                    self.out,
5157                    "float {name}(metal::ushort b0, \
5158                                  metal::ushort b1) {{"
5159                )?;
5160                writeln!(
5161                    self.out,
5162                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5163                    back::INDENT
5164                )?;
5165                writeln!(self.out, "}}")?;
5166                Ok((name, 2, 1))
5167            }
5168            Unorm16x2 => {
5169                let name = self.namer.call("unpackUnorm16x2");
5170                writeln!(
5171                    self.out,
5172                    "metal::float2 {name}(metal::ushort b0, \
5173                                          metal::ushort b1, \
5174                                          metal::ushort b2, \
5175                                          metal::ushort b3) {{"
5176                )?;
5177                writeln!(
5178                    self.out,
5179                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5180                                            float(b3 << 8 | b2) / 65535.0f);",
5181                    back::INDENT
5182                )?;
5183                writeln!(self.out, "}}")?;
5184                Ok((name, 4, 2))
5185            }
5186            Unorm16x4 => {
5187                let name = self.namer.call("unpackUnorm16x4");
5188                writeln!(
5189                    self.out,
5190                    "metal::float4 {name}(metal::ushort b0, \
5191                                          metal::ushort b1, \
5192                                          metal::ushort b2, \
5193                                          metal::ushort b3, \
5194                                          metal::ushort b4, \
5195                                          metal::ushort b5, \
5196                                          metal::ushort b6, \
5197                                          metal::ushort b7) {{"
5198                )?;
5199                writeln!(
5200                    self.out,
5201                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5202                                            float(b3 << 8 | b2) / 65535.0f, \
5203                                            float(b5 << 8 | b4) / 65535.0f, \
5204                                            float(b7 << 8 | b6) / 65535.0f);",
5205                    back::INDENT
5206                )?;
5207                writeln!(self.out, "}}")?;
5208                Ok((name, 8, 4))
5209            }
5210            Snorm16 => {
5211                let name = self.namer.call("unpackSnorm16");
5212                writeln!(
5213                    self.out,
5214                    "float {name}(metal::ushort b0, \
5215                                  metal::ushort b1) {{"
5216                )?;
5217                writeln!(
5218                    self.out,
5219                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5220                    back::INDENT
5221                )?;
5222                writeln!(self.out, "}}")?;
5223                Ok((name, 2, 1))
5224            }
5225            Snorm16x2 => {
5226                let name = self.namer.call("unpackSnorm16x2");
5227                writeln!(
5228                    self.out,
5229                    "metal::float2 {name}(uint b0, \
5230                                          uint b1, \
5231                                          uint b2, \
5232                                          uint b3) {{"
5233                )?;
5234                writeln!(
5235                    self.out,
5236                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5237                    back::INDENT
5238                )?;
5239                writeln!(self.out, "}}")?;
5240                Ok((name, 4, 2))
5241            }
5242            Snorm16x4 => {
5243                let name = self.namer.call("unpackSnorm16x4");
5244                writeln!(
5245                    self.out,
5246                    "metal::float4 {name}(uint b0, \
5247                                          uint b1, \
5248                                          uint b2, \
5249                                          uint b3, \
5250                                          uint b4, \
5251                                          uint b5, \
5252                                          uint b6, \
5253                                          uint b7) {{"
5254                )?;
5255                writeln!(
5256                    self.out,
5257                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5258                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5259                    back::INDENT
5260                )?;
5261                writeln!(self.out, "}}")?;
5262                Ok((name, 8, 4))
5263            }
5264            Float16 => {
5265                let name = self.namer.call("unpackFloat16");
5266                writeln!(
5267                    self.out,
5268                    "float {name}(metal::ushort b0, \
5269                                  metal::ushort b1) {{"
5270                )?;
5271                writeln!(
5272                    self.out,
5273                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5274                    back::INDENT
5275                )?;
5276                writeln!(self.out, "}}")?;
5277                Ok((name, 2, 1))
5278            }
5279            Float16x2 => {
5280                let name = self.namer.call("unpackFloat16x2");
5281                writeln!(
5282                    self.out,
5283                    "metal::float2 {name}(metal::ushort b0, \
5284                                          metal::ushort b1, \
5285                                          metal::ushort b2, \
5286                                          metal::ushort b3) {{"
5287                )?;
5288                writeln!(
5289                    self.out,
5290                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5291                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5292                    back::INDENT
5293                )?;
5294                writeln!(self.out, "}}")?;
5295                Ok((name, 4, 2))
5296            }
5297            Float16x4 => {
5298                let name = self.namer.call("unpackFloat16x4");
5299                writeln!(
5300                    self.out,
5301                    "metal::float4 {name}(metal::ushort b0, \
5302                                        metal::ushort b1, \
5303                                        metal::ushort b2, \
5304                                        metal::ushort b3, \
5305                                        metal::ushort b4, \
5306                                        metal::ushort b5, \
5307                                        metal::ushort b6, \
5308                                        metal::ushort b7) {{"
5309                )?;
5310                writeln!(
5311                    self.out,
5312                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5313                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5314                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5315                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5316                    back::INDENT
5317                )?;
5318                writeln!(self.out, "}}")?;
5319                Ok((name, 8, 4))
5320            }
5321            Float32 => {
5322                let name = self.namer.call("unpackFloat32");
5323                writeln!(
5324                    self.out,
5325                    "float {name}(uint b0, \
5326                                  uint b1, \
5327                                  uint b2, \
5328                                  uint b3) {{"
5329                )?;
5330                writeln!(
5331                    self.out,
5332                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5333                    back::INDENT
5334                )?;
5335                writeln!(self.out, "}}")?;
5336                Ok((name, 4, 1))
5337            }
5338            Float32x2 => {
5339                let name = self.namer.call("unpackFloat32x2");
5340                writeln!(
5341                    self.out,
5342                    "metal::float2 {name}(uint b0, \
5343                                          uint b1, \
5344                                          uint b2, \
5345                                          uint b3, \
5346                                          uint b4, \
5347                                          uint b5, \
5348                                          uint b6, \
5349                                          uint b7) {{"
5350                )?;
5351                writeln!(
5352                    self.out,
5353                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5354                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5355                    back::INDENT
5356                )?;
5357                writeln!(self.out, "}}")?;
5358                Ok((name, 8, 2))
5359            }
5360            Float32x3 => {
5361                let name = self.namer.call("unpackFloat32x3");
5362                writeln!(
5363                    self.out,
5364                    "metal::float3 {name}(uint b0, \
5365                                          uint b1, \
5366                                          uint b2, \
5367                                          uint b3, \
5368                                          uint b4, \
5369                                          uint b5, \
5370                                          uint b6, \
5371                                          uint b7, \
5372                                          uint b8, \
5373                                          uint b9, \
5374                                          uint b10, \
5375                                          uint b11) {{"
5376                )?;
5377                writeln!(
5378                    self.out,
5379                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5380                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5381                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5382                    back::INDENT
5383                )?;
5384                writeln!(self.out, "}}")?;
5385                Ok((name, 12, 3))
5386            }
5387            Float32x4 => {
5388                let name = self.namer.call("unpackFloat32x4");
5389                writeln!(
5390                    self.out,
5391                    "metal::float4 {name}(uint b0, \
5392                                          uint b1, \
5393                                          uint b2, \
5394                                          uint b3, \
5395                                          uint b4, \
5396                                          uint b5, \
5397                                          uint b6, \
5398                                          uint b7, \
5399                                          uint b8, \
5400                                          uint b9, \
5401                                          uint b10, \
5402                                          uint b11, \
5403                                          uint b12, \
5404                                          uint b13, \
5405                                          uint b14, \
5406                                          uint b15) {{"
5407                )?;
5408                writeln!(
5409                    self.out,
5410                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5411                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5412                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5413                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5414                    back::INDENT
5415                )?;
5416                writeln!(self.out, "}}")?;
5417                Ok((name, 16, 4))
5418            }
5419            Uint32 => {
5420                let name = self.namer.call("unpackUint32");
5421                writeln!(
5422                    self.out,
5423                    "uint {name}(uint b0, \
5424                                 uint b1, \
5425                                 uint b2, \
5426                                 uint b3) {{"
5427                )?;
5428                writeln!(
5429                    self.out,
5430                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5431                    back::INDENT
5432                )?;
5433                writeln!(self.out, "}}")?;
5434                Ok((name, 4, 1))
5435            }
5436            Uint32x2 => {
5437                let name = self.namer.call("unpackUint32x2");
5438                writeln!(
5439                    self.out,
5440                    "uint2 {name}(uint b0, \
5441                                  uint b1, \
5442                                  uint b2, \
5443                                  uint b3, \
5444                                  uint b4, \
5445                                  uint b5, \
5446                                  uint b6, \
5447                                  uint b7) {{"
5448                )?;
5449                writeln!(
5450                    self.out,
5451                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5452                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5453                    back::INDENT
5454                )?;
5455                writeln!(self.out, "}}")?;
5456                Ok((name, 8, 2))
5457            }
5458            Uint32x3 => {
5459                let name = self.namer.call("unpackUint32x3");
5460                writeln!(
5461                    self.out,
5462                    "uint3 {name}(uint b0, \
5463                                  uint b1, \
5464                                  uint b2, \
5465                                  uint b3, \
5466                                  uint b4, \
5467                                  uint b5, \
5468                                  uint b6, \
5469                                  uint b7, \
5470                                  uint b8, \
5471                                  uint b9, \
5472                                  uint b10, \
5473                                  uint b11) {{"
5474                )?;
5475                writeln!(
5476                    self.out,
5477                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5478                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5479                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5480                    back::INDENT
5481                )?;
5482                writeln!(self.out, "}}")?;
5483                Ok((name, 12, 3))
5484            }
5485            Uint32x4 => {
5486                let name = self.namer.call("unpackUint32x4");
5487                writeln!(
5488                    self.out,
5489                    "{NAMESPACE}::uint4 {name}(uint b0, \
5490                                  uint b1, \
5491                                  uint b2, \
5492                                  uint b3, \
5493                                  uint b4, \
5494                                  uint b5, \
5495                                  uint b6, \
5496                                  uint b7, \
5497                                  uint b8, \
5498                                  uint b9, \
5499                                  uint b10, \
5500                                  uint b11, \
5501                                  uint b12, \
5502                                  uint b13, \
5503                                  uint b14, \
5504                                  uint b15) {{"
5505                )?;
5506                writeln!(
5507                    self.out,
5508                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5509                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5510                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5511                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5512                    back::INDENT
5513                )?;
5514                writeln!(self.out, "}}")?;
5515                Ok((name, 16, 4))
5516            }
5517            Sint32 => {
5518                let name = self.namer.call("unpackSint32");
5519                writeln!(
5520                    self.out,
5521                    "int {name}(uint b0, \
5522                                uint b1, \
5523                                uint b2, \
5524                                uint b3) {{"
5525                )?;
5526                writeln!(
5527                    self.out,
5528                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5529                    back::INDENT
5530                )?;
5531                writeln!(self.out, "}}")?;
5532                Ok((name, 4, 1))
5533            }
5534            Sint32x2 => {
5535                let name = self.namer.call("unpackSint32x2");
5536                writeln!(
5537                    self.out,
5538                    "metal::int2 {name}(uint b0, \
5539                                        uint b1, \
5540                                        uint b2, \
5541                                        uint b3, \
5542                                        uint b4, \
5543                                        uint b5, \
5544                                        uint b6, \
5545                                        uint b7) {{"
5546                )?;
5547                writeln!(
5548                    self.out,
5549                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5550                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5551                    back::INDENT
5552                )?;
5553                writeln!(self.out, "}}")?;
5554                Ok((name, 8, 2))
5555            }
5556            Sint32x3 => {
5557                let name = self.namer.call("unpackSint32x3");
5558                writeln!(
5559                    self.out,
5560                    "metal::int3 {name}(uint b0, \
5561                                        uint b1, \
5562                                        uint b2, \
5563                                        uint b3, \
5564                                        uint b4, \
5565                                        uint b5, \
5566                                        uint b6, \
5567                                        uint b7, \
5568                                        uint b8, \
5569                                        uint b9, \
5570                                        uint b10, \
5571                                        uint b11) {{"
5572                )?;
5573                writeln!(
5574                    self.out,
5575                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5576                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5577                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5578                    back::INDENT
5579                )?;
5580                writeln!(self.out, "}}")?;
5581                Ok((name, 12, 3))
5582            }
5583            Sint32x4 => {
5584                let name = self.namer.call("unpackSint32x4");
5585                writeln!(
5586                    self.out,
5587                    "metal::int4 {name}(uint b0, \
5588                                        uint b1, \
5589                                        uint b2, \
5590                                        uint b3, \
5591                                        uint b4, \
5592                                        uint b5, \
5593                                        uint b6, \
5594                                        uint b7, \
5595                                        uint b8, \
5596                                        uint b9, \
5597                                        uint b10, \
5598                                        uint b11, \
5599                                        uint b12, \
5600                                        uint b13, \
5601                                        uint b14, \
5602                                        uint b15) {{"
5603                )?;
5604                writeln!(
5605                    self.out,
5606                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5607                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5608                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5609                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5610                    back::INDENT
5611                )?;
5612                writeln!(self.out, "}}")?;
5613                Ok((name, 16, 4))
5614            }
5615            Unorm10_10_10_2 => {
5616                let name = self.namer.call("unpackUnorm10_10_10_2");
5617                writeln!(
5618                    self.out,
5619                    "metal::float4 {name}(uint b0, \
5620                                          uint b1, \
5621                                          uint b2, \
5622                                          uint b3) {{"
5623                )?;
5624                writeln!(
5625                    self.out,
5626                    // The following is correct for RGBA packing, but our format seems to
5627                    // match ABGR, which can be fed into the Metal builtin function
5628                    // unpack_unorm10a2_to_float.
5629                    /*
5630                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5631                       uint r = (v & 0xFFC00000) >> 22; \
5632                       uint g = (v & 0x003FF000) >> 12; \
5633                       uint b = (v & 0x00000FFC) >> 2; \
5634                       uint a = (v & 0x00000003); \
5635                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5636                    */
5637                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5638                    back::INDENT
5639                )?;
5640                writeln!(self.out, "}}")?;
5641                Ok((name, 4, 4))
5642            }
5643            Unorm8x4Bgra => {
5644                let name = self.namer.call("unpackUnorm8x4Bgra");
5645                writeln!(
5646                    self.out,
5647                    "metal::float4 {name}(metal::uchar b0, \
5648                                          metal::uchar b1, \
5649                                          metal::uchar b2, \
5650                                          metal::uchar b3) {{"
5651                )?;
5652                writeln!(
5653                    self.out,
5654                    "{}return metal::float4(float(b2) / 255.0f, \
5655                                            float(b1) / 255.0f, \
5656                                            float(b0) / 255.0f, \
5657                                            float(b3) / 255.0f);",
5658                    back::INDENT
5659                )?;
5660                writeln!(self.out, "}}")?;
5661                Ok((name, 4, 4))
5662            }
5663        }
5664    }
5665
5666    fn write_wrapped_unary_op(
5667        &mut self,
5668        module: &crate::Module,
5669        func_ctx: &back::FunctionCtx,
5670        op: crate::UnaryOperator,
5671        operand: Handle<crate::Expression>,
5672    ) -> BackendResult {
5673        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5674        match op {
5675            // Negating the TYPE_MIN of a two's complement signed integer
5676            // type causes overflow, which is undefined behaviour in MSL. To
5677            // avoid this we bitcast the value to unsigned and negate it,
5678            // then bitcast back to signed.
5679            // This adheres to the WGSL spec in that the negative of the
5680            // type's minimum value should equal to the minimum value.
5681            crate::UnaryOperator::Negate
5682                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5683            {
5684                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5685                    return Ok(());
5686                };
5687                let wrapped = WrappedFunction::UnaryOp {
5688                    op,
5689                    ty: (vector_size, scalar),
5690                };
5691                if !self.wrapped_functions.insert(wrapped) {
5692                    return Ok(());
5693                }
5694
5695                let unsigned_scalar = crate::Scalar {
5696                    kind: crate::ScalarKind::Uint,
5697                    ..scalar
5698                };
5699                let mut type_name = String::new();
5700                let mut unsigned_type_name = String::new();
5701                match vector_size {
5702                    None => {
5703                        put_numeric_type(&mut type_name, scalar, &[])?;
5704                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5705                    }
5706                    Some(size) => {
5707                        put_numeric_type(&mut type_name, scalar, &[size])?;
5708                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5709                    }
5710                };
5711
5712                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5713                let level = back::Level(1);
5714                writeln!(
5715                    self.out,
5716                    "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5717                )?;
5718                writeln!(self.out, "}}")?;
5719                writeln!(self.out)?;
5720            }
5721            _ => {}
5722        }
5723        Ok(())
5724    }
5725
5726    fn write_wrapped_binary_op(
5727        &mut self,
5728        module: &crate::Module,
5729        func_ctx: &back::FunctionCtx,
5730        expr: Handle<crate::Expression>,
5731        op: crate::BinaryOperator,
5732        left: Handle<crate::Expression>,
5733        right: Handle<crate::Expression>,
5734    ) -> BackendResult {
5735        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5736        let left_ty = func_ctx.resolve_type(left, &module.types);
5737        let right_ty = func_ctx.resolve_type(right, &module.types);
5738        match (op, expr_ty.scalar_kind()) {
5739            // Signed integer division of TYPE_MIN / -1, or signed or
5740            // unsigned division by zero, gives an unspecified value in MSL.
5741            // We override the divisor to 1 in these cases.
5742            // This adheres to the WGSL spec in that:
5743            // * TYPE_MIN / -1 == TYPE_MIN
5744            // * x / 0 == x
5745            (
5746                crate::BinaryOperator::Divide,
5747                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5748            ) => {
5749                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5750                    return Ok(());
5751                };
5752                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5753                    return Ok(());
5754                };
5755                let wrapped = WrappedFunction::BinaryOp {
5756                    op,
5757                    left_ty: left_wrapped_ty,
5758                    right_ty: right_wrapped_ty,
5759                };
5760                if !self.wrapped_functions.insert(wrapped) {
5761                    return Ok(());
5762                }
5763
5764                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5765                    return Ok(());
5766                };
5767                let mut type_name = String::new();
5768                match vector_size {
5769                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5770                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5771                };
5772                writeln!(
5773                    self.out,
5774                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5775                )?;
5776                let level = back::Level(1);
5777                match scalar.kind {
5778                    crate::ScalarKind::Sint => {
5779                        let min_val = match scalar.width {
5780                            4 => crate::Literal::I32(i32::MIN),
5781                            8 => crate::Literal::I64(i64::MIN),
5782                            _ => {
5783                                return Err(Error::GenericValidation(format!(
5784                                    "Unexpected width for scalar {scalar:?}"
5785                                )));
5786                            }
5787                        };
5788                        write!(
5789                            self.out,
5790                            "{level}return lhs / metal::select(rhs, 1, (lhs == "
5791                        )?;
5792                        self.put_literal(min_val)?;
5793                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5794                    }
5795                    crate::ScalarKind::Uint => writeln!(
5796                        self.out,
5797                        "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5798                    )?,
5799                    _ => unreachable!(),
5800                }
5801                writeln!(self.out, "}}")?;
5802                writeln!(self.out)?;
5803            }
5804            // Integer modulo where one or both operands are negative, or the
5805            // divisor is zero, is undefined behaviour in MSL. To avoid this
5806            // we use the following equation:
5807            //
5808            // dividend - (dividend / divisor) * divisor
5809            //
5810            // overriding the divisor to 1 if either it is 0, or it is -1
5811            // and the dividend is TYPE_MIN.
5812            //
5813            // This adheres to the WGSL spec in that:
5814            // * TYPE_MIN % -1 == 0
5815            // * x % 0 == 0
5816            (
5817                crate::BinaryOperator::Modulo,
5818                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5819            ) => {
5820                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5821                    return Ok(());
5822                };
5823                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5824                else {
5825                    return Ok(());
5826                };
5827                let wrapped = WrappedFunction::BinaryOp {
5828                    op,
5829                    left_ty: left_wrapped_ty,
5830                    right_ty: (right_vector_size, right_scalar),
5831                };
5832                if !self.wrapped_functions.insert(wrapped) {
5833                    return Ok(());
5834                }
5835
5836                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5837                    return Ok(());
5838                };
5839                let mut type_name = String::new();
5840                match vector_size {
5841                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5842                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5843                };
5844                let mut rhs_type_name = String::new();
5845                match right_vector_size {
5846                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5847                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5848                };
5849
5850                writeln!(
5851                    self.out,
5852                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5853                )?;
5854                let level = back::Level(1);
5855                match scalar.kind {
5856                    crate::ScalarKind::Sint => {
5857                        let min_val = match scalar.width {
5858                            4 => crate::Literal::I32(i32::MIN),
5859                            8 => crate::Literal::I64(i64::MIN),
5860                            _ => {
5861                                return Err(Error::GenericValidation(format!(
5862                                    "Unexpected width for scalar {scalar:?}"
5863                                )));
5864                            }
5865                        };
5866                        write!(
5867                            self.out,
5868                            "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5869                        )?;
5870                        self.put_literal(min_val)?;
5871                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5872                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5873                    }
5874                    crate::ScalarKind::Uint => writeln!(
5875                        self.out,
5876                        "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5877                    )?,
5878                    _ => unreachable!(),
5879                }
5880                writeln!(self.out, "}}")?;
5881                writeln!(self.out)?;
5882            }
5883            _ => {}
5884        }
5885        Ok(())
5886    }
5887
5888    /// Build the mangled helper name for integer vector dot products.
5889    ///
5890    /// `scalar` must be a concrete integer scalar type.
5891    ///
5892    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5893    fn get_dot_wrapper_function_helper_name(
5894        &self,
5895        scalar: crate::Scalar,
5896        size: crate::VectorSize,
5897    ) -> String {
5898        // Check for consistency with [`super::keywords::RESERVED_SET`]
5899        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5900
5901        let type_name = scalar.to_msl_name();
5902        let size_suffix = common::vector_size_str(size);
5903        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5904    }
5905
5906    #[allow(clippy::too_many_arguments)]
5907    fn write_wrapped_math_function(
5908        &mut self,
5909        module: &crate::Module,
5910        func_ctx: &back::FunctionCtx,
5911        fun: crate::MathFunction,
5912        arg: Handle<crate::Expression>,
5913        _arg1: Option<Handle<crate::Expression>>,
5914        _arg2: Option<Handle<crate::Expression>>,
5915        _arg3: Option<Handle<crate::Expression>>,
5916    ) -> BackendResult {
5917        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5918        match fun {
5919            // Taking the absolute value of the TYPE_MIN of a two's
5920            // complement signed integer type causes overflow, which is
5921            // undefined behaviour in MSL. To avoid this, when the value is
5922            // negative we bitcast the value to unsigned and negate it, then
5923            // bitcast back to signed.
5924            // This adheres to the WGSL spec in that the absolute of the
5925            // type's minimum value should equal to the minimum value.
5926            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5927                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5928                    return Ok(());
5929                };
5930                let wrapped = WrappedFunction::Math {
5931                    fun,
5932                    arg_ty: (vector_size, scalar),
5933                };
5934                if !self.wrapped_functions.insert(wrapped) {
5935                    return Ok(());
5936                }
5937
5938                let unsigned_scalar = crate::Scalar {
5939                    kind: crate::ScalarKind::Uint,
5940                    ..scalar
5941                };
5942                let mut type_name = String::new();
5943                let mut unsigned_type_name = String::new();
5944                match vector_size {
5945                    None => {
5946                        put_numeric_type(&mut type_name, scalar, &[])?;
5947                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5948                    }
5949                    Some(size) => {
5950                        put_numeric_type(&mut type_name, scalar, &[size])?;
5951                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5952                    }
5953                };
5954
5955                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5956                let level = back::Level(1);
5957                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5958                writeln!(self.out, "}}")?;
5959                writeln!(self.out)?;
5960            }
5961
5962            crate::MathFunction::Dot => match *arg_ty {
5963                crate::TypeInner::Vector { size, scalar }
5964                    if matches!(
5965                        scalar.kind,
5966                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
5967                    ) =>
5968                {
5969                    // De-duplicate per (fun, arg type) like other wrapped math functions
5970                    let wrapped = WrappedFunction::Math {
5971                        fun,
5972                        arg_ty: (Some(size), scalar),
5973                    };
5974                    if !self.wrapped_functions.insert(wrapped) {
5975                        return Ok(());
5976                    }
5977
5978                    let mut vec_ty = String::new();
5979                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
5980                    let mut ret_ty = String::new();
5981                    put_numeric_type(&mut ret_ty, scalar, &[])?;
5982
5983                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5984
5985                    // Emit function signature and body using put_dot_product for the expression
5986                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
5987                    let level = back::Level(1);
5988                    write!(self.out, "{level}return ")?;
5989                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
5990                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
5991                        Ok(())
5992                    })?;
5993                    writeln!(self.out, ";")?;
5994                    writeln!(self.out, "}}")?;
5995                    writeln!(self.out)?;
5996                }
5997                _ => {}
5998            },
5999
6000            _ => {}
6001        }
6002        Ok(())
6003    }
6004
6005    fn write_wrapped_cast(
6006        &mut self,
6007        module: &crate::Module,
6008        func_ctx: &back::FunctionCtx,
6009        expr: Handle<crate::Expression>,
6010        kind: crate::ScalarKind,
6011        convert: Option<crate::Bytes>,
6012    ) -> BackendResult {
6013        // Avoid undefined behaviour when casting from a float to integer
6014        // when the value is out of range for the target type. Additionally
6015        // ensure we clamp to the correct value as per the WGSL spec.
6016        //
6017        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
6018        // * If X is exactly representable in the target type T, then the
6019        //   result is that value.
6020        // * Otherwise, the result is the value in T closest to
6021        //   truncate(X) and also exactly representable in the original
6022        //   floating point type.
6023        let src_ty = func_ctx.resolve_type(expr, &module.types);
6024        let Some(width) = convert else {
6025            return Ok(());
6026        };
6027        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6028            return Ok(());
6029        };
6030        let dst_scalar = crate::Scalar { kind, width };
6031        if src_scalar.kind != crate::ScalarKind::Float
6032            || (dst_scalar.kind != crate::ScalarKind::Sint
6033                && dst_scalar.kind != crate::ScalarKind::Uint)
6034        {
6035            return Ok(());
6036        }
6037        let wrapped = WrappedFunction::Cast {
6038            src_scalar,
6039            vector_size,
6040            dst_scalar,
6041        };
6042        if !self.wrapped_functions.insert(wrapped) {
6043            return Ok(());
6044        }
6045        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6046
6047        let mut src_type_name = String::new();
6048        match vector_size {
6049            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6050            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6051        };
6052        let mut dst_type_name = String::new();
6053        match vector_size {
6054            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6055            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6056        };
6057        let fun_name = match dst_scalar {
6058            crate::Scalar::I32 => F2I32_FUNCTION,
6059            crate::Scalar::U32 => F2U32_FUNCTION,
6060            crate::Scalar::I64 => F2I64_FUNCTION,
6061            crate::Scalar::U64 => F2U64_FUNCTION,
6062            _ => unreachable!(),
6063        };
6064
6065        writeln!(
6066            self.out,
6067            "{dst_type_name} {fun_name}({src_type_name} value) {{"
6068        )?;
6069        let level = back::Level(1);
6070        write!(
6071            self.out,
6072            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6073        )?;
6074        self.put_literal(min)?;
6075        write!(self.out, ", ")?;
6076        self.put_literal(max)?;
6077        writeln!(self.out, "));")?;
6078        writeln!(self.out, "}}")?;
6079        writeln!(self.out)?;
6080        Ok(())
6081    }
6082
6083    /// Helper function used by [`Self::write_wrapped_image_load`] and
6084    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
6085    /// conversion code for external textures. Expects the preceding code to
6086    /// declare the Y component as a `float` variable of name `y`, the UV
6087    /// components as a `float2` variable of name `uv`, and the external
6088    /// texture params as a variable of name `params`. The emitted code will
6089    /// return the result.
6090    fn write_convert_yuv_to_rgb_and_return(
6091        &mut self,
6092        level: back::Level,
6093        y: &str,
6094        uv: &str,
6095        params: &str,
6096    ) -> BackendResult {
6097        let l1 = level;
6098        let l2 = l1.next();
6099
6100        // Convert from YUV to non-linear RGB in the source color space.
6101        writeln!(
6102            self.out,
6103            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6104        )?;
6105
6106        // Apply the inverse of the source transfer function to convert to
6107        // linear RGB in the source color space.
6108        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6109        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6110        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6111        writeln!(
6112            self.out,
6113            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6114        )?;
6115
6116        // Multiply by the gamut conversion matrix to convert to linear RGB in
6117        // the destination color space.
6118        writeln!(
6119            self.out,
6120            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6121        )?;
6122
6123        // Finally, apply the dest transfer function to convert to non-linear
6124        // RGB in the destination color space, and return the result.
6125        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6126        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6127        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6128        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6129
6130        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6131        Ok(())
6132    }
6133
6134    #[allow(clippy::too_many_arguments)]
6135    fn write_wrapped_image_load(
6136        &mut self,
6137        module: &crate::Module,
6138        func_ctx: &back::FunctionCtx,
6139        image: Handle<crate::Expression>,
6140        _coordinate: Handle<crate::Expression>,
6141        _array_index: Option<Handle<crate::Expression>>,
6142        _sample: Option<Handle<crate::Expression>>,
6143        _level: Option<Handle<crate::Expression>>,
6144    ) -> BackendResult {
6145        // We currently only need to wrap image loads for external textures
6146        let class = match *func_ctx.resolve_type(image, &module.types) {
6147            crate::TypeInner::Image { class, .. } => class,
6148            _ => unreachable!(),
6149        };
6150        if class != crate::ImageClass::External {
6151            return Ok(());
6152        }
6153        let wrapped = WrappedFunction::ImageLoad { class };
6154        if !self.wrapped_functions.insert(wrapped) {
6155            return Ok(());
6156        }
6157
6158        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6159        let l1 = back::Level(1);
6160        let l2 = l1.next();
6161        let l3 = l2.next();
6162        writeln!(
6163            self.out,
6164            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6165        )?;
6166        // Clamp coords to provided size of external texture to prevent OOB
6167        // read. If params.size is zero then clamp to the actual size of the
6168        // texture.
6169        writeln!(
6170            self.out,
6171            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6172        )?;
6173        writeln!(
6174            self.out,
6175            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6176        )?;
6177
6178        // Apply load transformation
6179        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6180        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6181        // For single plane, simply read from plane0
6182        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6183        writeln!(self.out, "{l1}}} else {{")?;
6184
6185        // Chroma planes may be subsampled so we must scale the coords accordingly.
6186        writeln!(
6187            self.out,
6188            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6189        )?;
6190        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6191
6192        // For multi-plane, read the Y value from plane 0
6193        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6194
6195        writeln!(self.out, "{l2}float2 uv;")?;
6196        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6197        // For 2 planes, read UV from interleaved plane 1
6198        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6199        writeln!(self.out, "{l2}}} else {{")?;
6200        // For 3 planes, read U and V from planes 1 and 2 respectively
6201        writeln!(
6202            self.out,
6203            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6204        )?;
6205        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6206        writeln!(
6207            self.out,
6208            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6209        )?;
6210        writeln!(self.out, "{l2}}}")?;
6211
6212        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6213
6214        writeln!(self.out, "{l1}}}")?;
6215        writeln!(self.out, "}}")?;
6216        writeln!(self.out)?;
6217        Ok(())
6218    }
6219
6220    #[allow(clippy::too_many_arguments)]
6221    fn write_wrapped_image_sample(
6222        &mut self,
6223        module: &crate::Module,
6224        func_ctx: &back::FunctionCtx,
6225        image: Handle<crate::Expression>,
6226        _sampler: Handle<crate::Expression>,
6227        _gather: Option<crate::SwizzleComponent>,
6228        _coordinate: Handle<crate::Expression>,
6229        _array_index: Option<Handle<crate::Expression>>,
6230        _offset: Option<Handle<crate::Expression>>,
6231        _level: crate::SampleLevel,
6232        _depth_ref: Option<Handle<crate::Expression>>,
6233        clamp_to_edge: bool,
6234    ) -> BackendResult {
6235        // We currently only need to wrap textureSampleBaseClampToEdge, for
6236        // both sampled and external textures.
6237        if !clamp_to_edge {
6238            return Ok(());
6239        }
6240        let class = match *func_ctx.resolve_type(image, &module.types) {
6241            crate::TypeInner::Image { class, .. } => class,
6242            _ => unreachable!(),
6243        };
6244        let wrapped = WrappedFunction::ImageSample {
6245            class,
6246            clamp_to_edge: true,
6247        };
6248        if !self.wrapped_functions.insert(wrapped) {
6249            return Ok(());
6250        }
6251        match class {
6252            crate::ImageClass::External => {
6253                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6254                let l1 = back::Level(1);
6255                let l2 = l1.next();
6256                let l3 = l2.next();
6257                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6258                writeln!(
6259                    self.out,
6260                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6261                )?;
6262
6263                // Calculate the sample bounds. The purported size of the texture
6264                // (params.size) is irrelevant here as we are dealing with normalized
6265                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6266                // apply the sample transformation to that, also bearing in mind that it
6267                // may contain a flip on either axis. We calculate and adjust for the
6268                // half-texel separately for each plane as it depends on the actual
6269                // texture size which may vary between planes.
6270                writeln!(
6271                    self.out,
6272                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6273                )?;
6274                writeln!(
6275                    self.out,
6276                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6277                )?;
6278                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6279                writeln!(
6280                    self.out,
6281                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6282                )?;
6283                writeln!(
6284                    self.out,
6285                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6286                )?;
6287                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6288                // For single plane, simply sample from plane0
6289                writeln!(
6290                    self.out,
6291                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6292                )?;
6293                writeln!(self.out, "{l1}}} else {{")?;
6294                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6295                writeln!(
6296                    self.out,
6297                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6298                )?;
6299                writeln!(
6300                    self.out,
6301                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6302                )?;
6303
6304                // For multi-plane, sample the Y value from plane 0
6305                writeln!(
6306                    self.out,
6307                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6308                )?;
6309                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6310                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6311                // For 2 planes, sample UV from interleaved plane 1
6312                writeln!(
6313                    self.out,
6314                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6315                )?;
6316                writeln!(self.out, "{l2}}} else {{")?;
6317                // For 3 planes, sample U and V from planes 1 and 2 respectively
6318                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6319                writeln!(
6320                    self.out,
6321                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6322                )?;
6323                writeln!(
6324                    self.out,
6325                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6326                )?;
6327                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6328                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6329                writeln!(self.out, "{l2}}}")?;
6330
6331                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6332
6333                writeln!(self.out, "{l1}}}")?;
6334                writeln!(self.out, "}}")?;
6335                writeln!(self.out)?;
6336            }
6337            _ => {
6338                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6339                let l1 = back::Level(1);
6340                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6341                writeln!(
6342                    self.out,
6343                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6344                )?;
6345                writeln!(self.out, "}}")?;
6346                writeln!(self.out)?;
6347            }
6348        }
6349        Ok(())
6350    }
6351
6352    fn write_wrapped_image_query(
6353        &mut self,
6354        module: &crate::Module,
6355        func_ctx: &back::FunctionCtx,
6356        image: Handle<crate::Expression>,
6357        query: crate::ImageQuery,
6358    ) -> BackendResult {
6359        // We currently only need to wrap size image queries for external textures
6360        if !matches!(query, crate::ImageQuery::Size { .. }) {
6361            return Ok(());
6362        }
6363        let class = match *func_ctx.resolve_type(image, &module.types) {
6364            crate::TypeInner::Image { class, .. } => class,
6365            _ => unreachable!(),
6366        };
6367        if class != crate::ImageClass::External {
6368            return Ok(());
6369        }
6370        let wrapped = WrappedFunction::ImageQuerySize { class };
6371        if !self.wrapped_functions.insert(wrapped) {
6372            return Ok(());
6373        }
6374        writeln!(
6375            self.out,
6376            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6377        )?;
6378        let l1 = back::Level(1);
6379        let l2 = l1.next();
6380        writeln!(
6381            self.out,
6382            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6383        )?;
6384        writeln!(self.out, "{l2}return tex.params.size;")?;
6385        writeln!(self.out, "{l1}}} else {{")?;
6386        // params.size == (0, 0) indicates to query and return plane 0's actual size
6387        writeln!(
6388            self.out,
6389            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6390        )?;
6391        writeln!(self.out, "{l1}}}")?;
6392        writeln!(self.out, "}}")?;
6393        writeln!(self.out)?;
6394        Ok(())
6395    }
6396
6397    fn write_wrapped_cooperative_load(
6398        &mut self,
6399        module: &crate::Module,
6400        func_ctx: &back::FunctionCtx,
6401        columns: crate::CooperativeSize,
6402        rows: crate::CooperativeSize,
6403        pointer: Handle<crate::Expression>,
6404    ) -> BackendResult {
6405        let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6406        let space = ptr_ty.pointer_space().unwrap();
6407        let space_name = space.to_msl_name().unwrap_or_default();
6408        let scalar = ptr_ty
6409            .pointer_base_type()
6410            .unwrap()
6411            .inner_with(&module.types)
6412            .scalar()
6413            .unwrap();
6414        let wrapped = WrappedFunction::CooperativeLoad {
6415            space_name,
6416            columns,
6417            rows,
6418            scalar,
6419        };
6420        if !self.wrapped_functions.insert(wrapped) {
6421            return Ok(());
6422        }
6423        let scalar_name = scalar.to_msl_name();
6424        writeln!(
6425            self.out,
6426            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6427            columns as u32, rows as u32,
6428        )?;
6429        let l1 = back::Level(1);
6430        writeln!(
6431            self.out,
6432            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6433            columns as u32, rows as u32
6434        )?;
6435        let matrix_origin = "0";
6436        writeln!(
6437            self.out,
6438            "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6439        )?;
6440        writeln!(self.out, "{l1}return m;")?;
6441        writeln!(self.out, "}}")?;
6442        writeln!(self.out)?;
6443        Ok(())
6444    }
6445
6446    fn write_wrapped_cooperative_multiply_add(
6447        &mut self,
6448        module: &crate::Module,
6449        func_ctx: &back::FunctionCtx,
6450        space: crate::AddressSpace,
6451        a: Handle<crate::Expression>,
6452        b: Handle<crate::Expression>,
6453    ) -> BackendResult {
6454        let space_name = space.to_msl_name().unwrap_or_default();
6455        let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6456            crate::TypeInner::CooperativeMatrix {
6457                columns,
6458                rows,
6459                scalar,
6460                ..
6461            } => (columns, rows, scalar),
6462            _ => unreachable!(),
6463        };
6464        let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6465            crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6466            _ => unreachable!(),
6467        };
6468        let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6469            space_name,
6470            columns: b_c,
6471            rows: a_r,
6472            intermediate: a_c,
6473            scalar,
6474        };
6475        if !self.wrapped_functions.insert(wrapped) {
6476            return Ok(());
6477        }
6478        let scalar_name = scalar.to_msl_name();
6479        writeln!(
6480            self.out,
6481            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6482            b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6483        )?;
6484        let l1 = back::Level(1);
6485        writeln!(
6486            self.out,
6487            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6488            b_c as u32, a_r as u32
6489        )?;
6490        writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6491        writeln!(self.out, "{l1}return d;")?;
6492        writeln!(self.out, "}}")?;
6493        writeln!(self.out)?;
6494        Ok(())
6495    }
6496
6497    pub(super) fn write_wrapped_functions(
6498        &mut self,
6499        module: &crate::Module,
6500        func_ctx: &back::FunctionCtx,
6501    ) -> BackendResult {
6502        for (expr_handle, expr) in func_ctx.expressions.iter() {
6503            match *expr {
6504                crate::Expression::Unary { op, expr: operand } => {
6505                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6506                }
6507                crate::Expression::Binary { op, left, right } => {
6508                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6509                }
6510                crate::Expression::Math {
6511                    fun,
6512                    arg,
6513                    arg1,
6514                    arg2,
6515                    arg3,
6516                } => {
6517                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6518                }
6519                crate::Expression::As {
6520                    expr,
6521                    kind,
6522                    convert,
6523                } => {
6524                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6525                }
6526                crate::Expression::ImageLoad {
6527                    image,
6528                    coordinate,
6529                    array_index,
6530                    sample,
6531                    level,
6532                } => {
6533                    self.write_wrapped_image_load(
6534                        module,
6535                        func_ctx,
6536                        image,
6537                        coordinate,
6538                        array_index,
6539                        sample,
6540                        level,
6541                    )?;
6542                }
6543                crate::Expression::ImageSample {
6544                    image,
6545                    sampler,
6546                    gather,
6547                    coordinate,
6548                    array_index,
6549                    offset,
6550                    level,
6551                    depth_ref,
6552                    clamp_to_edge,
6553                } => {
6554                    self.write_wrapped_image_sample(
6555                        module,
6556                        func_ctx,
6557                        image,
6558                        sampler,
6559                        gather,
6560                        coordinate,
6561                        array_index,
6562                        offset,
6563                        level,
6564                        depth_ref,
6565                        clamp_to_edge,
6566                    )?;
6567                }
6568                crate::Expression::ImageQuery { image, query } => {
6569                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6570                }
6571                crate::Expression::CooperativeLoad {
6572                    columns,
6573                    rows,
6574                    role: _,
6575                    ref data,
6576                } => {
6577                    self.write_wrapped_cooperative_load(
6578                        module,
6579                        func_ctx,
6580                        columns,
6581                        rows,
6582                        data.pointer,
6583                    )?;
6584                }
6585                crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6586                    let space = crate::AddressSpace::Private;
6587                    self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6588                }
6589                _ => {}
6590            }
6591        }
6592
6593        Ok(())
6594    }
6595
6596    // Returns the array of mapped entry point names.
6597    fn write_functions(
6598        &mut self,
6599        module: &crate::Module,
6600        mod_info: &valid::ModuleInfo,
6601        options: &Options,
6602        pipeline_options: &PipelineOptions,
6603    ) -> Result<TranslationInfo, Error> {
6604        use back::msl::VertexFormat;
6605
6606        // Define structs to hold resolved/generated data for vertex buffers and
6607        // their attributes.
6608        struct AttributeMappingResolved {
6609            ty_name: String,
6610            dimension: u32,
6611            ty_is_int: bool,
6612            name: String,
6613        }
6614        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6615
6616        struct VertexBufferMappingResolved<'a> {
6617            id: u32,
6618            stride: u32,
6619            step_mode: back::msl::VertexBufferStepMode,
6620            ty_name: String,
6621            param_name: String,
6622            elem_name: String,
6623            attributes: &'a Vec<back::msl::AttributeMapping>,
6624        }
6625        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6626
6627        // Define a struct to hold a named reference to a byte-unpacking function.
6628        struct UnpackingFunction {
6629            name: String,
6630            byte_count: u32,
6631            dimension: u32,
6632        }
6633        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6634
6635        // Check if we are attempting vertex pulling. If we are, generate some
6636        // names we'll need, and iterate the vertex buffer mappings to output
6637        // all the conversion functions we'll need to unpack the attribute data.
6638        // We can re-use these names for all entry points that need them, since
6639        // those entry points also use self.namer.
6640        let mut needs_vertex_id = false;
6641        let v_id = self.namer.call("v_id");
6642
6643        let mut needs_instance_id = false;
6644        let i_id = self.namer.call("i_id");
6645        if pipeline_options.vertex_pulling_transform {
6646            for vbm in &pipeline_options.vertex_buffer_mappings {
6647                let buffer_id = vbm.id;
6648                let buffer_stride = vbm.stride;
6649
6650                assert!(
6651                    buffer_stride > 0,
6652                    "Vertex pulling requires a non-zero buffer stride."
6653                );
6654
6655                match vbm.step_mode {
6656                    back::msl::VertexBufferStepMode::Constant => {}
6657                    back::msl::VertexBufferStepMode::ByVertex => {
6658                        needs_vertex_id = true;
6659                    }
6660                    back::msl::VertexBufferStepMode::ByInstance => {
6661                        needs_instance_id = true;
6662                    }
6663                }
6664
6665                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6666                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6667                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6668
6669                vbm_resolved.push(VertexBufferMappingResolved {
6670                    id: buffer_id,
6671                    stride: buffer_stride,
6672                    step_mode: vbm.step_mode,
6673                    ty_name: buffer_ty,
6674                    param_name: buffer_param,
6675                    elem_name: buffer_elem,
6676                    attributes: &vbm.attributes,
6677                });
6678
6679                // Iterate the attributes and generate needed unpacking functions.
6680                for attribute in &vbm.attributes {
6681                    if unpacking_functions.contains_key(&attribute.format) {
6682                        continue;
6683                    }
6684                    let (name, byte_count, dimension) =
6685                        match self.write_unpacking_function(attribute.format) {
6686                            Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6687                            _ => {
6688                                continue;
6689                            }
6690                        };
6691                    unpacking_functions.insert(
6692                        attribute.format,
6693                        UnpackingFunction {
6694                            name,
6695                            byte_count,
6696                            dimension,
6697                        },
6698                    );
6699                }
6700            }
6701        }
6702
6703        let mut pass_through_globals = Vec::new();
6704        for (fun_handle, fun) in module.functions.iter() {
6705            log::trace!(
6706                "function {:?}, handle {:?}",
6707                fun.name.as_deref().unwrap_or("(anonymous)"),
6708                fun_handle
6709            );
6710
6711            let ctx = back::FunctionCtx {
6712                ty: back::FunctionType::Function(fun_handle),
6713                info: &mod_info[fun_handle],
6714                expressions: &fun.expressions,
6715                named_expressions: &fun.named_expressions,
6716            };
6717
6718            writeln!(self.out)?;
6719            self.write_wrapped_functions(module, &ctx)?;
6720
6721            let fun_info = &mod_info[fun_handle];
6722            pass_through_globals.clear();
6723            let mut needs_buffer_sizes = false;
6724            for (handle, var) in module.global_variables.iter() {
6725                if !fun_info[handle].is_empty() {
6726                    if var.space.needs_pass_through() {
6727                        pass_through_globals.push(handle);
6728                    }
6729                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6730                }
6731            }
6732
6733            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6734            match fun.result {
6735                Some(ref result) => {
6736                    let ty_name = TypeContext {
6737                        handle: result.ty,
6738                        gctx: module.to_ctx(),
6739                        names: &self.names,
6740                        access: crate::StorageAccess::empty(),
6741                        first_time: false,
6742                    };
6743                    write!(self.out, "{ty_name}")?;
6744                }
6745                None => {
6746                    write!(self.out, "void")?;
6747                }
6748            }
6749            writeln!(self.out, " {fun_name}(")?;
6750
6751            for (index, arg) in fun.arguments.iter().enumerate() {
6752                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6753                let param_type_name = TypeContext {
6754                    handle: arg.ty,
6755                    gctx: module.to_ctx(),
6756                    names: &self.names,
6757                    access: crate::StorageAccess::empty(),
6758                    first_time: false,
6759                };
6760                let separator = separate(
6761                    !pass_through_globals.is_empty()
6762                        || index + 1 != fun.arguments.len()
6763                        || needs_buffer_sizes,
6764                );
6765                writeln!(
6766                    self.out,
6767                    "{}{} {}{}",
6768                    back::INDENT,
6769                    param_type_name,
6770                    name,
6771                    separator
6772                )?;
6773            }
6774            for (index, &handle) in pass_through_globals.iter().enumerate() {
6775                let tyvar = TypedGlobalVariable {
6776                    module,
6777                    names: &self.names,
6778                    handle,
6779                    usage: fun_info[handle],
6780                    reference: true,
6781                };
6782                let separator =
6783                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6784                write!(self.out, "{}", back::INDENT)?;
6785                tyvar.try_fmt(&mut self.out)?;
6786                writeln!(self.out, "{separator}")?;
6787            }
6788
6789            if needs_buffer_sizes {
6790                writeln!(
6791                    self.out,
6792                    "{}constant _mslBufferSizes& _buffer_sizes",
6793                    back::INDENT
6794                )?;
6795            }
6796
6797            writeln!(self.out, ") {{")?;
6798
6799            let guarded_indices =
6800                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6801
6802            let context = StatementContext {
6803                expression: ExpressionContext {
6804                    function: fun,
6805                    origin: FunctionOrigin::Handle(fun_handle),
6806                    info: fun_info,
6807                    lang_version: options.lang_version,
6808                    policies: options.bounds_check_policies,
6809                    guarded_indices,
6810                    module,
6811                    mod_info,
6812                    pipeline_options,
6813                    force_loop_bounding: options.force_loop_bounding,
6814                },
6815                result_struct: None,
6816            };
6817
6818            self.put_locals(&context.expression)?;
6819            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6820            self.put_block(back::Level(1), &fun.body, &context)?;
6821            writeln!(self.out, "}}")?;
6822            self.named_expressions.clear();
6823        }
6824
6825        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6826            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6827
6828        let mut info = TranslationInfo {
6829            entry_point_names: Vec::with_capacity(ep_range.len()),
6830        };
6831
6832        for ep_index in ep_range {
6833            let ep = &module.entry_points[ep_index];
6834            let fun = &ep.function;
6835            let fun_info = mod_info.get_entry_point(ep_index);
6836            let mut ep_error = None;
6837
6838            // For vertex_id and instance_id arguments, presume that we'll
6839            // use our generated names, but switch to the name of an
6840            // existing @builtin param, if we find one.
6841            let mut v_existing_id = None;
6842            let mut i_existing_id = None;
6843
6844            log::trace!(
6845                "entry point {:?}, index {:?}",
6846                fun.name.as_deref().unwrap_or("(anonymous)"),
6847                ep_index
6848            );
6849
6850            let ctx = back::FunctionCtx {
6851                ty: back::FunctionType::EntryPoint(ep_index as u16),
6852                info: fun_info,
6853                expressions: &fun.expressions,
6854                named_expressions: &fun.named_expressions,
6855            };
6856
6857            self.write_wrapped_functions(module, &ctx)?;
6858
6859            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6860                crate::ShaderStage::Vertex => (
6861                    "vertex",
6862                    LocationMode::VertexInput,
6863                    LocationMode::VertexOutput,
6864                    true,
6865                ),
6866                crate::ShaderStage::Fragment => (
6867                    "fragment",
6868                    LocationMode::FragmentInput,
6869                    LocationMode::FragmentOutput,
6870                    false,
6871                ),
6872                crate::ShaderStage::Compute => (
6873                    "kernel",
6874                    LocationMode::Uniform,
6875                    LocationMode::Uniform,
6876                    false,
6877                ),
6878                crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6879            };
6880
6881            // Should this entry point be modified to do vertex pulling?
6882            let do_vertex_pulling = can_vertex_pull
6883                && pipeline_options.vertex_pulling_transform
6884                && !pipeline_options.vertex_buffer_mappings.is_empty();
6885
6886            // Is any global variable used by this entry point dynamically sized?
6887            let needs_buffer_sizes = do_vertex_pulling
6888                || module
6889                    .global_variables
6890                    .iter()
6891                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6892                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6893
6894            // skip this entry point if any global bindings are missing,
6895            // or their types are incompatible.
6896            if !options.fake_missing_bindings {
6897                for (var_handle, var) in module.global_variables.iter() {
6898                    if fun_info[var_handle].is_empty() {
6899                        continue;
6900                    }
6901                    match var.space {
6902                        crate::AddressSpace::Uniform
6903                        | crate::AddressSpace::Storage { .. }
6904                        | crate::AddressSpace::Handle => {
6905                            let br = match var.binding {
6906                                Some(ref br) => br,
6907                                None => {
6908                                    let var_name = var.name.clone().unwrap_or_default();
6909                                    ep_error =
6910                                        Some(super::EntryPointError::MissingBinding(var_name));
6911                                    break;
6912                                }
6913                            };
6914                            let target = options.get_resource_binding_target(ep, br);
6915                            let good = match target {
6916                                Some(target) => {
6917                                    // We intentionally don't dereference binding_arrays here,
6918                                    // so that binding arrays fall to the buffer location.
6919
6920                                    match module.types[var.ty].inner {
6921                                        crate::TypeInner::Image {
6922                                            class: crate::ImageClass::External,
6923                                            ..
6924                                        } => target.external_texture.is_some(),
6925                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6926                                        crate::TypeInner::Sampler { .. } => {
6927                                            target.sampler.is_some()
6928                                        }
6929                                        _ => target.buffer.is_some(),
6930                                    }
6931                                }
6932                                None => false,
6933                            };
6934                            if !good {
6935                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6936                                break;
6937                            }
6938                        }
6939                        crate::AddressSpace::Immediate => {
6940                            if let Err(e) = options.resolve_immediates(ep) {
6941                                ep_error = Some(e);
6942                                break;
6943                            }
6944                        }
6945                        crate::AddressSpace::TaskPayload => {
6946                            unimplemented!()
6947                        }
6948                        crate::AddressSpace::Function
6949                        | crate::AddressSpace::Private
6950                        | crate::AddressSpace::WorkGroup => {}
6951                    }
6952                }
6953                if needs_buffer_sizes {
6954                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6955                        ep_error = Some(err);
6956                    }
6957                }
6958            }
6959
6960            if let Some(err) = ep_error {
6961                info.entry_point_names.push(Err(err));
6962                continue;
6963            }
6964            let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6965            info.entry_point_names.push(Ok(fun_name.clone()));
6966
6967            writeln!(self.out)?;
6968
6969            // Since `Namer.reset` wasn't expecting struct members to be
6970            // suddenly injected into another namespace like this,
6971            // `self.names` doesn't keep them distinct from other variables.
6972            // Generate fresh names for these arguments, and remember the
6973            // mapping.
6974            let mut flattened_member_names = FastHashMap::default();
6975            // Varyings' members get their own namespace
6976            let mut varyings_namer = proc::Namer::default();
6977
6978            // List all the Naga `EntryPoint`'s `Function`'s arguments,
6979            // flattening structs into their members. In Metal, we will pass
6980            // each of these values to the entry point as a separate argument—
6981            // except for the varyings, handled next.
6982            let mut flattened_arguments = Vec::new();
6983            for (arg_index, arg) in fun.arguments.iter().enumerate() {
6984                match module.types[arg.ty].inner {
6985                    crate::TypeInner::Struct { ref members, .. } => {
6986                        for (member_index, member) in members.iter().enumerate() {
6987                            let member_index = member_index as u32;
6988                            flattened_arguments.push((
6989                                NameKey::StructMember(arg.ty, member_index),
6990                                member.ty,
6991                                member.binding.as_ref(),
6992                            ));
6993                            let name_key = NameKey::StructMember(arg.ty, member_index);
6994                            let name = match member.binding {
6995                                Some(crate::Binding::Location { .. }) => {
6996                                    if do_vertex_pulling {
6997                                        self.namer.call(&self.names[&name_key])
6998                                    } else {
6999                                        varyings_namer.call(&self.names[&name_key])
7000                                    }
7001                                }
7002                                _ => self.namer.call(&self.names[&name_key]),
7003                            };
7004                            flattened_member_names.insert(name_key, name);
7005                        }
7006                    }
7007                    _ => flattened_arguments.push((
7008                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7009                        arg.ty,
7010                        arg.binding.as_ref(),
7011                    )),
7012                }
7013            }
7014
7015            // Identify the varyings among the argument values, and maybe emit
7016            // a struct type named `<fun>Input` to hold them. If we are doing
7017            // vertex pulling, we instead update our attribute mapping to
7018            // note the types, names, and zero values of the attributes.
7019            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7020            let varyings_member_name = self.namer.call("varyings");
7021            let mut has_varyings = false;
7022            if !flattened_arguments.is_empty() {
7023                if !do_vertex_pulling {
7024                    writeln!(self.out, "struct {stage_in_name} {{")?;
7025                }
7026                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7027                    let Some(binding) = binding else {
7028                        continue;
7029                    };
7030                    let name = match *name_key {
7031                        NameKey::StructMember(..) => &flattened_member_names[name_key],
7032                        _ => &self.names[name_key],
7033                    };
7034                    let ty_name = TypeContext {
7035                        handle: ty,
7036                        gctx: module.to_ctx(),
7037                        names: &self.names,
7038                        access: crate::StorageAccess::empty(),
7039                        first_time: false,
7040                    };
7041                    let resolved = options.resolve_local_binding(binding, in_mode)?;
7042                    let location = match *binding {
7043                        crate::Binding::Location { location, .. } => Some(location),
7044                        crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7045                        crate::Binding::BuiltIn(_) => continue,
7046                    };
7047                    if do_vertex_pulling {
7048                        let Some(location) = location else {
7049                            continue;
7050                        };
7051                        // Update our attribute mapping.
7052                        am_resolved.insert(
7053                            location,
7054                            AttributeMappingResolved {
7055                                ty_name: ty_name.to_string(),
7056                                dimension: ty_name.vertex_input_dimension(),
7057                                ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
7058                                name: name.to_string(),
7059                            },
7060                        );
7061                    } else {
7062                        has_varyings = true;
7063                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7064                        resolved.try_fmt(&mut self.out)?;
7065                        writeln!(self.out, ";")?;
7066                    }
7067                }
7068                if !do_vertex_pulling {
7069                    writeln!(self.out, "}};")?;
7070                }
7071            }
7072
7073            // Define a struct type for the return value, if any, named
7074            // `<fun>Output`.
7075            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7076            let result_member_name = self.namer.call("member");
7077            let result_type_name = match fun.result {
7078                Some(ref result) => {
7079                    let mut result_members = Vec::new();
7080                    if let crate::TypeInner::Struct { ref members, .. } =
7081                        module.types[result.ty].inner
7082                    {
7083                        for (member_index, member) in members.iter().enumerate() {
7084                            result_members.push((
7085                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7086                                member.ty,
7087                                member.binding.as_ref(),
7088                            ));
7089                        }
7090                    } else {
7091                        result_members.push((
7092                            &result_member_name,
7093                            result.ty,
7094                            result.binding.as_ref(),
7095                        ));
7096                    }
7097
7098                    writeln!(self.out, "struct {stage_out_name} {{")?;
7099                    let mut has_point_size = false;
7100                    for (name, ty, binding) in result_members {
7101                        let ty_name = TypeContext {
7102                            handle: ty,
7103                            gctx: module.to_ctx(),
7104                            names: &self.names,
7105                            access: crate::StorageAccess::empty(),
7106                            first_time: true,
7107                        };
7108                        let binding = binding.ok_or_else(|| {
7109                            Error::GenericValidation("Expected binding, got None".into())
7110                        })?;
7111
7112                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7113                            has_point_size = true;
7114                            if !pipeline_options.allow_and_force_point_size {
7115                                continue;
7116                            }
7117                        }
7118
7119                        let array_len = match module.types[ty].inner {
7120                            crate::TypeInner::Array {
7121                                size: crate::ArraySize::Constant(size),
7122                                ..
7123                            } => Some(size),
7124                            _ => None,
7125                        };
7126                        let resolved = options.resolve_local_binding(binding, out_mode)?;
7127                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7128                        if let Some(array_len) = array_len {
7129                            write!(self.out, " [{array_len}]")?;
7130                        }
7131                        resolved.try_fmt(&mut self.out)?;
7132                        writeln!(self.out, ";")?;
7133                    }
7134
7135                    if pipeline_options.allow_and_force_point_size
7136                        && ep.stage == crate::ShaderStage::Vertex
7137                        && !has_point_size
7138                    {
7139                        // inject the point size output last
7140                        writeln!(
7141                            self.out,
7142                            "{}float _point_size [[point_size]];",
7143                            back::INDENT
7144                        )?;
7145                    }
7146                    writeln!(self.out, "}};")?;
7147                    &stage_out_name
7148                }
7149                None => "void",
7150            };
7151
7152            // If we're doing a vertex pulling transform, define the buffer
7153            // structure types.
7154            if do_vertex_pulling {
7155                for vbm in &vbm_resolved {
7156                    let buffer_stride = vbm.stride;
7157                    let buffer_ty = &vbm.ty_name;
7158
7159                    // Define a structure of bytes of the appropriate size.
7160                    // When we access the attributes, we'll be unpacking these
7161                    // bytes at some offset.
7162                    writeln!(
7163                        self.out,
7164                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7165                    )?;
7166                }
7167            }
7168
7169            // Write the entry point function's name, and begin its argument list.
7170            writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
7171
7172            let mut is_first_argument = true;
7173            let mut separator = || {
7174                if is_first_argument {
7175                    is_first_argument = false;
7176                    ' '
7177                } else {
7178                    ','
7179                }
7180            };
7181
7182            // If we have produced a struct holding the `EntryPoint`'s
7183            // `Function`'s arguments' varyings, pass that struct first.
7184            if has_varyings {
7185                writeln!(
7186                    self.out,
7187                    "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
7188                    separator()
7189                )?;
7190            }
7191
7192            let mut local_invocation_id = None;
7193
7194            // Then pass the remaining arguments not included in the varyings
7195            // struct.
7196            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7197                let binding = match binding {
7198                    Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7199                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7200                    _ => continue,
7201                };
7202                let name = match *name_key {
7203                    NameKey::StructMember(..) => &flattened_member_names[name_key],
7204                    _ => &self.names[name_key],
7205                };
7206
7207                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
7208                    local_invocation_id = Some(name_key);
7209                }
7210
7211                let ty_name = TypeContext {
7212                    handle: ty,
7213                    gctx: module.to_ctx(),
7214                    names: &self.names,
7215                    access: crate::StorageAccess::empty(),
7216                    first_time: false,
7217                };
7218
7219                match *binding {
7220                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7221                        v_existing_id = Some(name.clone());
7222                    }
7223                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7224                        i_existing_id = Some(name.clone());
7225                    }
7226                    _ => {}
7227                };
7228
7229                let resolved = options.resolve_local_binding(binding, in_mode)?;
7230                write!(self.out, "{} {ty_name} {name}", separator())?;
7231                resolved.try_fmt(&mut self.out)?;
7232                writeln!(self.out)?;
7233            }
7234
7235            let need_workgroup_variables_initialization =
7236                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7237
7238            if need_workgroup_variables_initialization && local_invocation_id.is_none() {
7239                writeln!(
7240                    self.out,
7241                    "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
7242                )?;
7243            }
7244
7245            // Those global variables used by this entry point and its callees
7246            // get passed as arguments. `Private` globals are an exception, they
7247            // don't outlive this invocation, so we declare them below as locals
7248            // within the entry point.
7249            for (handle, var) in module.global_variables.iter() {
7250                let usage = fun_info[handle];
7251                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7252                    continue;
7253                }
7254
7255                if options.lang_version < (1, 2) {
7256                    match var.space {
7257                        // This restriction is not documented in the MSL spec
7258                        // but validation will fail if it is not upheld.
7259                        //
7260                        // We infer the required version from the "Function
7261                        // Buffer Read-Writes" section of [what's new], where
7262                        // the feature sets listed correspond with the ones
7263                        // supporting MSL 1.2.
7264                        //
7265                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7266                        crate::AddressSpace::Storage { access }
7267                            if access.contains(crate::StorageAccess::STORE)
7268                                && ep.stage == crate::ShaderStage::Fragment =>
7269                        {
7270                            return Err(Error::UnsupportedWriteableStorageBuffer)
7271                        }
7272                        crate::AddressSpace::Handle => {
7273                            match module.types[var.ty].inner {
7274                                crate::TypeInner::Image {
7275                                    class: crate::ImageClass::Storage { access, .. },
7276                                    ..
7277                                } => {
7278                                    // This restriction is not documented in the MSL spec
7279                                    // but validation will fail if it is not upheld.
7280                                    //
7281                                    // We infer the required version from the "Function
7282                                    // Texture Read-Writes" section of [what's new], where
7283                                    // the feature sets listed correspond with the ones
7284                                    // supporting MSL 1.2.
7285                                    //
7286                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7287                                    if access.contains(crate::StorageAccess::STORE)
7288                                        && (ep.stage == crate::ShaderStage::Vertex
7289                                            || ep.stage == crate::ShaderStage::Fragment)
7290                                    {
7291                                        return Err(Error::UnsupportedWriteableStorageTexture(
7292                                            ep.stage,
7293                                        ));
7294                                    }
7295
7296                                    if access.contains(
7297                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7298                                    ) {
7299                                        return Err(Error::UnsupportedRWStorageTexture);
7300                                    }
7301                                }
7302                                _ => {}
7303                            }
7304                        }
7305                        _ => {}
7306                    }
7307                }
7308
7309                // Check min MSL version for binding arrays
7310                match var.space {
7311                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7312                        crate::TypeInner::BindingArray { base, .. } => {
7313                            match module.types[base].inner {
7314                                crate::TypeInner::Sampler { .. } => {
7315                                    if options.lang_version < (2, 0) {
7316                                        return Err(Error::UnsupportedArrayOf(
7317                                            "samplers".to_string(),
7318                                        ));
7319                                    }
7320                                }
7321                                crate::TypeInner::Image { class, .. } => match class {
7322                                    crate::ImageClass::Sampled { .. }
7323                                    | crate::ImageClass::Depth { .. }
7324                                    | crate::ImageClass::Storage {
7325                                        access: crate::StorageAccess::LOAD,
7326                                        ..
7327                                    } => {
7328                                        // Array of textures since:
7329                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7330                                        // - macOS: Metal 2
7331
7332                                        if options.lang_version < (2, 0) {
7333                                            return Err(Error::UnsupportedArrayOf(
7334                                                "textures".to_string(),
7335                                            ));
7336                                        }
7337                                    }
7338                                    crate::ImageClass::Storage {
7339                                        access: crate::StorageAccess::STORE,
7340                                        ..
7341                                    } => {
7342                                        // Array of write-only textures since:
7343                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7344                                        // - macOS: Metal 2
7345
7346                                        if options.lang_version < (2, 0) {
7347                                            return Err(Error::UnsupportedArrayOf(
7348                                                "write-only textures".to_string(),
7349                                            ));
7350                                        }
7351                                    }
7352                                    crate::ImageClass::Storage { .. } => {
7353                                        if options.lang_version < (3, 0) {
7354                                            return Err(Error::UnsupportedArrayOf(
7355                                                "read-write textures".to_string(),
7356                                            ));
7357                                        }
7358                                    }
7359                                    crate::ImageClass::External => {
7360                                        return Err(Error::UnsupportedArrayOf(
7361                                            "external textures".to_string(),
7362                                        ));
7363                                    }
7364                                },
7365                                _ => {
7366                                    return Err(Error::UnsupportedArrayOfType(base));
7367                                }
7368                            }
7369                        }
7370                        _ => {}
7371                    },
7372                    _ => {}
7373                }
7374
7375                // The resolves have already been checked for the `!fake_missing_bindings` case.
7376                let resolved = match var.space {
7377                    crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7378                    crate::AddressSpace::WorkGroup => None,
7379                    _ => options
7380                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7381                        .ok(),
7382                };
7383                if let Some(ref resolved) = resolved {
7384                    // Inline samplers are defined in the EP body
7385                    if resolved.as_inline_sampler(options).is_some() {
7386                        continue;
7387                    }
7388                }
7389
7390                match module.types[var.ty].inner {
7391                    crate::TypeInner::Image {
7392                        class: crate::ImageClass::External,
7393                        ..
7394                    } => {
7395                        // External texture global variables get lowered to 3 textures
7396                        // and a constant buffer. We must emit a separate argument for
7397                        // each of these.
7398                        let target = match resolved {
7399                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7400                                target.external_texture
7401                            }
7402                            _ => None,
7403                        };
7404
7405                        for i in 0..3 {
7406                            write!(self.out, "{} ", separator())?;
7407
7408                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7409                                handle,
7410                                ExternalTextureNameKey::Plane(i),
7411                            )];
7412                            write!(
7413                              self.out,
7414                              "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7415                            )?;
7416                            if let Some(ref target) = target {
7417                                write!(self.out, " [[texture({})]]", target.planes[i])?;
7418                            }
7419                            writeln!(self.out)?;
7420                        }
7421                        let params_ty_name = &self.names
7422                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7423                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7424                            handle,
7425                            ExternalTextureNameKey::Params,
7426                        )];
7427                        write!(self.out, "{} ", separator())?;
7428                        write!(self.out, "constant {params_ty_name}& {params_name}")?;
7429                        if let Some(ref target) = target {
7430                            write!(self.out, " [[buffer({})]]", target.params)?;
7431                        }
7432                    }
7433                    _ => {
7434                        let tyvar = TypedGlobalVariable {
7435                            module,
7436                            names: &self.names,
7437                            handle,
7438                            usage,
7439                            reference: true,
7440                        };
7441                        write!(self.out, "{} ", separator())?;
7442                        tyvar.try_fmt(&mut self.out)?;
7443                        if let Some(resolved) = resolved {
7444                            resolved.try_fmt(&mut self.out)?;
7445                        }
7446                        if let Some(value) = var.init {
7447                            write!(self.out, " = ")?;
7448                            self.put_const_expression(
7449                                value,
7450                                module,
7451                                mod_info,
7452                                &module.global_expressions,
7453                            )?;
7454                        }
7455                    }
7456                }
7457                writeln!(self.out)?;
7458            }
7459
7460            if do_vertex_pulling {
7461                if needs_vertex_id && v_existing_id.is_none() {
7462                    // Write the [[vertex_id]] argument.
7463                    writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7464                }
7465
7466                if needs_instance_id && i_existing_id.is_none() {
7467                    writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7468                }
7469
7470                // Iterate vbm_resolved, output one argument for every vertex buffer,
7471                // using the names we generated earlier.
7472                for vbm in &vbm_resolved {
7473                    let id = &vbm.id;
7474                    let ty_name = &vbm.ty_name;
7475                    let param_name = &vbm.param_name;
7476                    writeln!(
7477                        self.out,
7478                        "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7479                        separator()
7480                    )?;
7481                }
7482            }
7483
7484            // If this entry uses any variable-length arrays, their sizes are
7485            // passed as a final struct-typed argument.
7486            if needs_buffer_sizes {
7487                // this is checked earlier
7488                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7489                write!(
7490                    self.out,
7491                    "{} constant _mslBufferSizes& _buffer_sizes",
7492                    separator()
7493                )?;
7494                resolved.try_fmt(&mut self.out)?;
7495                writeln!(self.out)?;
7496            }
7497
7498            // end of the entry point argument list
7499            writeln!(self.out, ") {{")?;
7500
7501            // Starting the function body.
7502            if do_vertex_pulling {
7503                // Provide zero values for all the attributes, which we will overwrite with
7504                // real data from the vertex attribute buffers, if the indices are in-bounds.
7505                for vbm in &vbm_resolved {
7506                    for attribute in vbm.attributes {
7507                        let location = attribute.shader_location;
7508                        let am_option = am_resolved.get(&location);
7509                        if am_option.is_none() {
7510                            // This bound attribute isn't used in this entry point, so
7511                            // don't bother zero-initializing it.
7512                            continue;
7513                        }
7514                        let am = am_option.unwrap();
7515                        let attribute_ty_name = &am.ty_name;
7516                        let attribute_name = &am.name;
7517
7518                        writeln!(
7519                            self.out,
7520                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7521                            back::Level(1)
7522                        )?;
7523                    }
7524
7525                    // Output a bounds check block that will set real values for the
7526                    // attributes, if the bounds are satisfied.
7527                    write!(self.out, "{}if (", back::Level(1))?;
7528
7529                    let idx = &vbm.id;
7530                    let stride = &vbm.stride;
7531                    let index_name = match vbm.step_mode {
7532                        back::msl::VertexBufferStepMode::Constant => "0",
7533                        back::msl::VertexBufferStepMode::ByVertex => {
7534                            if let Some(ref name) = v_existing_id {
7535                                name
7536                            } else {
7537                                &v_id
7538                            }
7539                        }
7540                        back::msl::VertexBufferStepMode::ByInstance => {
7541                            if let Some(ref name) = i_existing_id {
7542                                name
7543                            } else {
7544                                &i_id
7545                            }
7546                        }
7547                    };
7548                    write!(
7549                        self.out,
7550                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7551                    )?;
7552
7553                    writeln!(self.out, ") {{")?;
7554
7555                    // Pull the bytes out of the vertex buffer.
7556                    let ty_name = &vbm.ty_name;
7557                    let elem_name = &vbm.elem_name;
7558                    let param_name = &vbm.param_name;
7559
7560                    writeln!(
7561                        self.out,
7562                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7563                        back::Level(2),
7564                    )?;
7565
7566                    // Now set real values for each of the attributes, by unpacking the data
7567                    // from the buffer elements.
7568                    for attribute in vbm.attributes {
7569                        let location = attribute.shader_location;
7570                        let Some(am) = am_resolved.get(&location) else {
7571                            // This bound attribute isn't used in this entry point, so
7572                            // don't bother extracting the data. Too bad we emitted the
7573                            // unpacking function earlier -- it might not get used.
7574                            continue;
7575                        };
7576                        let attribute_name = &am.name;
7577                        let attribute_ty_name = &am.ty_name;
7578
7579                        let offset = attribute.offset;
7580                        let func = unpacking_functions
7581                            .get(&attribute.format)
7582                            .expect("Should have generated this unpacking function earlier.");
7583                        let func_name = &func.name;
7584
7585                        // Check dimensionality of the attribute compared to the unpacking
7586                        // function. If attribute dimension > unpack dimension, we have to
7587                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7588                        // scalar type. Otherwise, if attribute dimension is < unpack
7589                        // dimension, then we need to explicitly truncate the result.
7590
7591                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7592
7593                        if needs_padding_or_truncation != Ordering::Equal {
7594                            // Emit a comment flagging that a conversion is happening,
7595                            // since the actual logic can be at the end of a long line.
7596                            writeln!(
7597                                self.out,
7598                                "{}// {attribute_ty_name} <- {:?}",
7599                                back::Level(2),
7600                                attribute.format
7601                            )?;
7602                        }
7603
7604                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7605
7606                        if needs_padding_or_truncation == Ordering::Greater {
7607                            // Needs padding: emit constructor call for wider type
7608                            write!(self.out, "{attribute_ty_name}(")?;
7609                        }
7610
7611                        // Emit call to unpacking function
7612                        write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7613                        for i in (offset + 1)..(offset + func.byte_count) {
7614                            write!(self.out, ", {elem_name}.data[{i}]")?;
7615                        }
7616                        write!(self.out, ")")?;
7617
7618                        match needs_padding_or_truncation {
7619                            Ordering::Greater => {
7620                                // Padding
7621                                let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7622                                let one_value = if am.ty_is_int { "1" } else { "1.0" };
7623                                for i in func.dimension..am.dimension {
7624                                    write!(
7625                                        self.out,
7626                                        ", {}",
7627                                        if i == 3 { one_value } else { zero_value }
7628                                    )?;
7629                                }
7630                                write!(self.out, ")")?;
7631                            }
7632                            Ordering::Less => {
7633                                // Truncate to the first `am.dimension` components
7634                                write!(
7635                                    self.out,
7636                                    ".{}",
7637                                    &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7638                                )?;
7639                            }
7640                            Ordering::Equal => {}
7641                        }
7642
7643                        writeln!(self.out, ";")?;
7644                    }
7645
7646                    // End the bounds check / attribute setting block.
7647                    writeln!(self.out, "{}}}", back::Level(1))?;
7648                }
7649            }
7650
7651            if need_workgroup_variables_initialization {
7652                self.write_workgroup_variables_initialization(
7653                    module,
7654                    mod_info,
7655                    fun_info,
7656                    local_invocation_id,
7657                )?;
7658            }
7659
7660            // Metal doesn't support private mutable variables outside of functions,
7661            // so we put them here, just like the locals.
7662            for (handle, var) in module.global_variables.iter() {
7663                let usage = fun_info[handle];
7664                if usage.is_empty() {
7665                    continue;
7666                }
7667                if var.space == crate::AddressSpace::Private {
7668                    let tyvar = TypedGlobalVariable {
7669                        module,
7670                        names: &self.names,
7671                        handle,
7672                        usage,
7673
7674                        reference: false,
7675                    };
7676                    write!(self.out, "{}", back::INDENT)?;
7677                    tyvar.try_fmt(&mut self.out)?;
7678                    match var.init {
7679                        Some(value) => {
7680                            write!(self.out, " = ")?;
7681                            self.put_const_expression(
7682                                value,
7683                                module,
7684                                mod_info,
7685                                &module.global_expressions,
7686                            )?;
7687                            writeln!(self.out, ";")?;
7688                        }
7689                        None => {
7690                            writeln!(self.out, " = {{}};")?;
7691                        }
7692                    };
7693                } else if let Some(ref binding) = var.binding {
7694                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7695                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7696                        // write an inline sampler
7697                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7698                        writeln!(
7699                            self.out,
7700                            "{}constexpr {}::sampler {}(",
7701                            back::INDENT,
7702                            NAMESPACE,
7703                            name
7704                        )?;
7705                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7706                        writeln!(self.out, "{});", back::INDENT)?;
7707                    } else if let crate::TypeInner::Image {
7708                        class: crate::ImageClass::External,
7709                        ..
7710                    } = module.types[var.ty].inner
7711                    {
7712                        // Wrap the individual arguments for each external texture global
7713                        // in a struct which can be easily passed around.
7714                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7715                        let l1 = back::Level(1);
7716                        let l2 = l1.next();
7717                        writeln!(
7718                            self.out,
7719                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7720                        )?;
7721                        for i in 0..3 {
7722                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7723                                handle,
7724                                ExternalTextureNameKey::Plane(i),
7725                            )];
7726                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7727                        }
7728                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7729                            handle,
7730                            ExternalTextureNameKey::Params,
7731                        )];
7732                        writeln!(self.out, "{l2}.params = {params_name},")?;
7733                        writeln!(self.out, "{l1}}};")?;
7734                    }
7735                }
7736            }
7737
7738            // Now take the arguments that we gathered into structs, and the
7739            // structs that we flattened into arguments, and emit local
7740            // variables with initializers that put everything back the way the
7741            // body code expects.
7742            //
7743            // If we had to generate fresh names for struct members passed as
7744            // arguments, be sure to use those names when rebuilding the struct.
7745            //
7746            // "Each day, I change some zeros to ones, and some ones to zeros.
7747            // The rest, I leave alone."
7748            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7749                let arg_name =
7750                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7751                match module.types[arg.ty].inner {
7752                    crate::TypeInner::Struct { ref members, .. } => {
7753                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7754                        write!(
7755                            self.out,
7756                            "{}const {} {} = {{ ",
7757                            back::INDENT,
7758                            struct_name,
7759                            arg_name
7760                        )?;
7761                        for (member_index, member) in members.iter().enumerate() {
7762                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7763                            let name = &flattened_member_names[&key];
7764                            if member_index != 0 {
7765                                write!(self.out, ", ")?;
7766                            }
7767                            // insert padding initialization, if needed
7768                            if self
7769                                .struct_member_pads
7770                                .contains(&(arg.ty, member_index as u32))
7771                            {
7772                                write!(self.out, "{{}}, ")?;
7773                            }
7774                            if let Some(crate::Binding::Location { .. }) = member.binding {
7775                                if has_varyings {
7776                                    write!(self.out, "{varyings_member_name}.")?;
7777                                }
7778                            }
7779                            write!(self.out, "{name}")?;
7780                        }
7781                        writeln!(self.out, " }};")?;
7782                    }
7783                    _ => match arg.binding {
7784                        Some(crate::Binding::Location { .. })
7785                        | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
7786                            if has_varyings {
7787                                writeln!(
7788                                    self.out,
7789                                    "{}const auto {} = {}.{};",
7790                                    back::INDENT,
7791                                    arg_name,
7792                                    varyings_member_name,
7793                                    arg_name
7794                                )?;
7795                            }
7796                        }
7797                        _ => {}
7798                    },
7799                }
7800            }
7801
7802            let guarded_indices =
7803                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7804
7805            let context = StatementContext {
7806                expression: ExpressionContext {
7807                    function: fun,
7808                    origin: FunctionOrigin::EntryPoint(ep_index as _),
7809                    info: fun_info,
7810                    lang_version: options.lang_version,
7811                    policies: options.bounds_check_policies,
7812                    guarded_indices,
7813                    module,
7814                    mod_info,
7815                    pipeline_options,
7816                    force_loop_bounding: options.force_loop_bounding,
7817                },
7818                result_struct: Some(&stage_out_name),
7819            };
7820
7821            // Finally, declare all the local variables that we need
7822            //TODO: we can postpone this till the relevant expressions are emitted
7823            self.put_locals(&context.expression)?;
7824            self.update_expressions_to_bake(fun, fun_info, &context.expression);
7825            self.put_block(back::Level(1), &fun.body, &context)?;
7826            writeln!(self.out, "}}")?;
7827            if ep_index + 1 != module.entry_points.len() {
7828                writeln!(self.out)?;
7829            }
7830            self.named_expressions.clear();
7831        }
7832
7833        Ok(info)
7834    }
7835
7836    fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7837        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
7838        // so we try to avoid it here.
7839        if flags.is_empty() {
7840            writeln!(
7841                self.out,
7842                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7843            )?;
7844        }
7845        if flags.contains(crate::Barrier::STORAGE) {
7846            writeln!(
7847                self.out,
7848                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7849            )?;
7850        }
7851        if flags.contains(crate::Barrier::WORK_GROUP) {
7852            writeln!(
7853                self.out,
7854                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7855            )?;
7856        }
7857        if flags.contains(crate::Barrier::SUB_GROUP) {
7858            writeln!(
7859                self.out,
7860                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7861            )?;
7862        }
7863        if flags.contains(crate::Barrier::TEXTURE) {
7864            writeln!(
7865                self.out,
7866                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7867            )?;
7868        }
7869        Ok(())
7870    }
7871}
7872
7873/// Initializing workgroup variables is more tricky for Metal because we have to deal
7874/// with atomics at the type-level (which don't have a copy constructor).
7875mod workgroup_mem_init {
7876    use crate::EntryPoint;
7877
7878    use super::*;
7879
7880    enum Access {
7881        GlobalVariable(Handle<crate::GlobalVariable>),
7882        StructMember(Handle<crate::Type>, u32),
7883        Array(usize),
7884    }
7885
7886    impl Access {
7887        fn write<W: Write>(
7888            &self,
7889            writer: &mut W,
7890            names: &FastHashMap<NameKey, String>,
7891        ) -> Result<(), core::fmt::Error> {
7892            match *self {
7893                Access::GlobalVariable(handle) => {
7894                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7895                }
7896                Access::StructMember(handle, index) => {
7897                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7898                }
7899                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7900            }
7901        }
7902    }
7903
7904    struct AccessStack {
7905        stack: Vec<Access>,
7906        array_depth: usize,
7907    }
7908
7909    impl AccessStack {
7910        const fn new() -> Self {
7911            Self {
7912                stack: Vec::new(),
7913                array_depth: 0,
7914            }
7915        }
7916
7917        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7918            let array_depth = self.array_depth;
7919            self.stack.push(Access::Array(array_depth));
7920            self.array_depth += 1;
7921            let res = cb(self, array_depth);
7922            self.stack.pop();
7923            self.array_depth -= 1;
7924            res
7925        }
7926
7927        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7928            self.stack.push(new);
7929            let res = cb(self);
7930            self.stack.pop();
7931            res
7932        }
7933
7934        fn write<W: Write>(
7935            &self,
7936            writer: &mut W,
7937            names: &FastHashMap<NameKey, String>,
7938        ) -> Result<(), core::fmt::Error> {
7939            for next in self.stack.iter() {
7940                next.write(writer, names)?;
7941            }
7942            Ok(())
7943        }
7944    }
7945
7946    impl<W: Write> Writer<W> {
7947        pub(super) fn need_workgroup_variables_initialization(
7948            &mut self,
7949            options: &Options,
7950            ep: &EntryPoint,
7951            module: &crate::Module,
7952            fun_info: &valid::FunctionInfo,
7953        ) -> bool {
7954            options.zero_initialize_workgroup_memory
7955                && ep.stage.compute_like()
7956                && module.global_variables.iter().any(|(handle, var)| {
7957                    !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7958                })
7959        }
7960
7961        pub(super) fn write_workgroup_variables_initialization(
7962            &mut self,
7963            module: &crate::Module,
7964            module_info: &valid::ModuleInfo,
7965            fun_info: &valid::FunctionInfo,
7966            local_invocation_id: Option<&NameKey>,
7967        ) -> BackendResult {
7968            let level = back::Level(1);
7969
7970            writeln!(
7971                self.out,
7972                "{}if ({}::all({} == {}::uint3(0u))) {{",
7973                level,
7974                NAMESPACE,
7975                local_invocation_id
7976                    .map(|name_key| self.names[name_key].as_str())
7977                    .unwrap_or("__local_invocation_id"),
7978                NAMESPACE,
7979            )?;
7980
7981            let mut access_stack = AccessStack::new();
7982
7983            let vars = module.global_variables.iter().filter(|&(handle, var)| {
7984                !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7985            });
7986
7987            for (handle, var) in vars {
7988                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7989                    self.write_workgroup_variable_initialization(
7990                        module,
7991                        module_info,
7992                        var.ty,
7993                        access_stack,
7994                        level.next(),
7995                    )
7996                })?;
7997            }
7998
7999            writeln!(self.out, "{level}}}")?;
8000            self.write_barrier(crate::Barrier::WORK_GROUP, level)
8001        }
8002
8003        fn write_workgroup_variable_initialization(
8004            &mut self,
8005            module: &crate::Module,
8006            module_info: &valid::ModuleInfo,
8007            ty: Handle<crate::Type>,
8008            access_stack: &mut AccessStack,
8009            level: back::Level,
8010        ) -> BackendResult {
8011            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8012                write!(self.out, "{level}")?;
8013                access_stack.write(&mut self.out, &self.names)?;
8014                writeln!(self.out, " = {{}};")?;
8015            } else {
8016                match module.types[ty].inner {
8017                    crate::TypeInner::Atomic { .. } => {
8018                        write!(
8019                            self.out,
8020                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8021                        )?;
8022                        access_stack.write(&mut self.out, &self.names)?;
8023                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8024                    }
8025                    crate::TypeInner::Array { base, size, .. } => {
8026                        let count = match size.resolve(module.to_ctx())? {
8027                            proc::IndexableLength::Known(count) => count,
8028                            proc::IndexableLength::Dynamic => unreachable!(),
8029                        };
8030
8031                        access_stack.enter_array(|access_stack, array_depth| {
8032                            writeln!(
8033                                self.out,
8034                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8035                            )?;
8036                            self.write_workgroup_variable_initialization(
8037                                module,
8038                                module_info,
8039                                base,
8040                                access_stack,
8041                                level.next(),
8042                            )?;
8043                            writeln!(self.out, "{level}}}")?;
8044                            BackendResult::Ok(())
8045                        })?;
8046                    }
8047                    crate::TypeInner::Struct { ref members, .. } => {
8048                        for (index, member) in members.iter().enumerate() {
8049                            access_stack.enter(
8050                                Access::StructMember(ty, index as u32),
8051                                |access_stack| {
8052                                    self.write_workgroup_variable_initialization(
8053                                        module,
8054                                        module_info,
8055                                        member.ty,
8056                                        access_stack,
8057                                        level,
8058                                    )
8059                                },
8060                            )?;
8061                        }
8062                    }
8063                    _ => unreachable!(),
8064                }
8065            }
8066
8067            Ok(())
8068        }
8069    }
8070}
8071
8072impl crate::AtomicFunction {
8073    const fn to_msl(self) -> &'static str {
8074        match self {
8075            Self::Add => "fetch_add",
8076            Self::Subtract => "fetch_sub",
8077            Self::And => "fetch_and",
8078            Self::InclusiveOr => "fetch_or",
8079            Self::ExclusiveOr => "fetch_xor",
8080            Self::Min => "fetch_min",
8081            Self::Max => "fetch_max",
8082            Self::Exchange { compare: None } => "exchange",
8083            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8084        }
8085    }
8086
8087    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8088        Ok(match self {
8089            Self::Min => "min",
8090            Self::Max => "max",
8091            _ => Err(Error::FeatureNotImplemented(
8092                "64-bit atomic operation other than min/max".to_string(),
8093            ))?,
8094        })
8095    }
8096}