naga/back/msl/
writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18    arena::{Handle, HandleSet},
19    back::{self, get_entry_points, Baked},
20    common,
21    proc::{
22        self, concrete_int_scalars,
23        index::{self, BoundsCheck},
24        ExternalTextureNameKey, NameKey, TypeResolution,
25    },
26    valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Shorthand result used internally by the backend
type BackendResult = Result<(), Error>;

// Namespace used to qualify the Metal names this writer emits
// (for example `metal::sampler`).
const NAMESPACE: &str = "metal";
// The name of the array member of the Metal struct types we generate to
// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
// for details.
const WRAPPED_ARRAY_FIELD: &str = "inner";
// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// Some more general handling of pointers is needed to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

// Namespace used to qualify ray-tracing names
// (for example `metal::raytracing::instance_acceleration_structure`).
const RT_NAMESPACE: &str = "metal::raytracing";
// Type name written for Naga `RayQuery` types.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// NOTE(review): judging by their names, the remaining `RAY_QUERY_*` items are
// the member names and helper-function name that accompany `RAY_QUERY_TYPE`;
// their uses are outside this view, so confirm before relying on that.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_MODERN_SUPPORT: bool = false; //TODO
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
52
// Identifiers that appear in the generated MSL.
//
// NOTE(review): presumably each `*_FUNCTION` constant names a helper function
// that the writer defines in its output on demand; the code emitting those
// definitions is outside this view, so confirm there.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
// Unlike its neighbors this is only a prefix; the rest of the name is
// appended elsewhere (not visible here).
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you have noticed that this should be exactly the same to the compiler, and you're correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
// NOTE(review): presumably the names of the helpers generated for the
// `WrappedFunction::CooperativeLoad` / `CooperativeMultiplyAdd` keys — confirm
// at the use sites.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
84
85/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
86///
87/// The `sizes` slice determines whether this function writes a
88/// scalar, vector, or matrix type:
89///
90/// - An empty slice produces a scalar type.
91/// - A one-element slice produces a vector type.
92/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
93fn put_numeric_type(
94    out: &mut impl Write,
95    scalar: crate::Scalar,
96    sizes: &[crate::VectorSize],
97) -> Result<(), FmtError> {
98    match (scalar, sizes) {
99        (scalar, &[]) => {
100            write!(out, "{}", scalar.to_msl_name())
101        }
102        (scalar, &[rows]) => {
103            write!(
104                out,
105                "{}::{}{}",
106                NAMESPACE,
107                scalar.to_msl_name(),
108                common::vector_size_str(rows)
109            )
110        }
111        (scalar, &[rows, columns]) => {
112            write!(
113                out,
114                "{}::{}{}x{}",
115                NAMESPACE,
116                scalar.to_msl_name(),
117                common::vector_size_str(columns),
118                common::vector_size_str(rows)
119            )
120        }
121        (_, _) => Ok(()), // not meaningful
122    }
123}
124
125const fn scalar_is_int(scalar: crate::Scalar) -> bool {
126    use crate::ScalarKind::*;
127    match scalar.kind {
128        Sint | Uint | AbstractInt | Bool => true,
129        Float | AbstractFloat => false,
130    }
131}
132
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
///
/// Used by the `Display` implementation of `ClampedLod`.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
///
/// Used by the `Display` implementation of `Reinterpreted`.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
138
139/// Wrapper for identifier names for clamped level-of-detail values
140///
141/// Values of this type implement [`core::fmt::Display`], formatting as
142/// the name of the variable used to hold the cached clamped
143/// level-of-detail value for an `ImageLoad` expression.
144struct ClampedLod(Handle<crate::Expression>);
145
146impl Display for ClampedLod {
147    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
148        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
149    }
150}
151
152/// Wrapper for generating `struct _mslBufferSizes` member names for
153/// runtime-sized array lengths.
154///
155/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
156/// as an argument to the entry point. This argument's type in the MSL is
157/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
158/// each global variable containing a runtime-sized array.
159///
160/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
161/// runtime-sized array, then the value `ArraySize(global)` implements
162/// [`core::fmt::Display`], formatting as the name of the struct member carrying
163/// the number of elements in that runtime-sized array.
164///
165/// [`GlobalVariable`]: crate::GlobalVariable
166struct ArraySizeMember(Handle<crate::GlobalVariable>);
167
168impl Display for ArraySizeMember {
169    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
170        self.0.write_prefixed(f, "size")
171    }
172}
173
174/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
175///
176/// Implements [`core::fmt::Display`], formatting as a name derived from
177/// `target_type` and the variable name of `orig`.
178#[derive(Clone, Copy)]
179struct Reinterpreted<'a> {
180    target_type: &'a str,
181    orig: Handle<crate::Expression>,
182}
183
184impl<'a> Reinterpreted<'a> {
185    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
186        Self { target_type, orig }
187    }
188}
189
190impl Display for Reinterpreted<'_> {
191    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
192        f.write_str(REINTERPRET_PREFIX)?;
193        f.write_str(self.target_type)?;
194        self.orig.write_prefixed(f, "_e")
195    }
196}
197
198struct TypeContext<'a> {
199    handle: Handle<crate::Type>,
200    gctx: proc::GlobalCtx<'a>,
201    names: &'a FastHashMap<NameKey, String>,
202    access: crate::StorageAccess,
203    first_time: bool,
204}
205
206impl TypeContext<'_> {
207    fn scalar(&self) -> Option<crate::Scalar> {
208        let ty = &self.gctx.types[self.handle];
209        ty.inner.scalar()
210    }
211
212    fn vertex_input_dimension(&self) -> u32 {
213        let ty = &self.gctx.types[self.handle];
214        match ty.inner {
215            crate::TypeInner::Scalar(_) => 1,
216            crate::TypeInner::Vector { size, .. } => size as u32,
217            _ => unreachable!(),
218        }
219    }
220}
221
222impl Display for TypeContext<'_> {
223    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
224        let ty = &self.gctx.types[self.handle];
225        if ty.needs_alias() && !self.first_time {
226            let name = &self.names[&NameKey::Type(self.handle)];
227            return write!(out, "{name}");
228        }
229
230        match ty.inner {
231            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
232            crate::TypeInner::Atomic(scalar) => {
233                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
234            }
235            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
236            crate::TypeInner::Matrix {
237                columns,
238                rows,
239                scalar,
240            } => put_numeric_type(out, scalar, &[rows, columns]),
241            // Requires Metal-2.3
242            crate::TypeInner::CooperativeMatrix {
243                columns,
244                rows,
245                scalar,
246                role: _,
247            } => {
248                write!(
249                    out,
250                    "{NAMESPACE}::simdgroup_{}{}x{}",
251                    scalar.to_msl_name(),
252                    columns as u32,
253                    rows as u32,
254                )
255            }
256            crate::TypeInner::Pointer { base, space } => {
257                let sub = Self {
258                    handle: base,
259                    first_time: false,
260                    ..*self
261                };
262                let space_name = match space.to_msl_name() {
263                    Some(name) => name,
264                    None => return Ok(()),
265                };
266                write!(out, "{space_name} {sub}&")
267            }
268            crate::TypeInner::ValuePointer {
269                size,
270                scalar,
271                space,
272            } => {
273                match space.to_msl_name() {
274                    Some(name) => write!(out, "{name} ")?,
275                    None => return Ok(()),
276                };
277                match size {
278                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
279                    None => put_numeric_type(out, scalar, &[])?,
280                };
281
282                write!(out, "&")
283            }
284            crate::TypeInner::Array { base, .. } => {
285                let sub = Self {
286                    handle: base,
287                    first_time: false,
288                    ..*self
289                };
290                // Array lengths go at the end of the type definition,
291                // so just print the element type here.
292                write!(out, "{sub}")
293            }
294            crate::TypeInner::Struct { .. } => unreachable!(),
295            crate::TypeInner::Image {
296                dim,
297                arrayed,
298                class,
299            } => {
300                let dim_str = match dim {
301                    crate::ImageDimension::D1 => "1d",
302                    crate::ImageDimension::D2 => "2d",
303                    crate::ImageDimension::D3 => "3d",
304                    crate::ImageDimension::Cube => "cube",
305                };
306                let (texture_str, msaa_str, scalar, access) = match class {
307                    crate::ImageClass::Sampled { kind, multi } => {
308                        let (msaa_str, access) = if multi {
309                            ("_ms", "read")
310                        } else {
311                            ("", "sample")
312                        };
313                        let scalar = crate::Scalar { kind, width: 4 };
314                        ("texture", msaa_str, scalar, access)
315                    }
316                    crate::ImageClass::Depth { multi } => {
317                        let (msaa_str, access) = if multi {
318                            ("_ms", "read")
319                        } else {
320                            ("", "sample")
321                        };
322                        let scalar = crate::Scalar {
323                            kind: crate::ScalarKind::Float,
324                            width: 4,
325                        };
326                        ("depth", msaa_str, scalar, access)
327                    }
328                    crate::ImageClass::Storage { format, .. } => {
329                        let access = if self
330                            .access
331                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
332                        {
333                            "read_write"
334                        } else if self.access.contains(crate::StorageAccess::STORE) {
335                            "write"
336                        } else if self.access.contains(crate::StorageAccess::LOAD) {
337                            "read"
338                        } else {
339                            log::warn!(
340                                "Storage access for {:?} (name '{}'): {:?}",
341                                self.handle,
342                                ty.name.as_deref().unwrap_or_default(),
343                                self.access
344                            );
345                            unreachable!("module is not valid");
346                        };
347                        ("texture", "", format.into(), access)
348                    }
349                    crate::ImageClass::External => {
350                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
351                    }
352                };
353                let base_name = scalar.to_msl_name();
354                let array_str = if arrayed { "_array" } else { "" };
355                write!(
356                    out,
357                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
358                )
359            }
360            crate::TypeInner::Sampler { comparison: _ } => {
361                write!(out, "{NAMESPACE}::sampler")
362            }
363            crate::TypeInner::AccelerationStructure { vertex_return } => {
364                if vertex_return {
365                    unimplemented!("metal does not support vertex ray hit return")
366                }
367                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
368            }
369            crate::TypeInner::RayQuery { vertex_return } => {
370                if vertex_return {
371                    unimplemented!("metal does not support vertex ray hit return")
372                }
373                write!(out, "{RAY_QUERY_TYPE}")
374            }
375            crate::TypeInner::BindingArray { base, .. } => {
376                let base_tyname = Self {
377                    handle: base,
378                    first_time: false,
379                    ..*self
380                };
381
382                write!(
383                    out,
384                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
385                )
386            }
387        }
388    }
389}
390
391struct TypedGlobalVariable<'a> {
392    module: &'a crate::Module,
393    names: &'a FastHashMap<NameKey, String>,
394    handle: Handle<crate::GlobalVariable>,
395    usage: valid::GlobalUse,
396    reference: bool,
397}
398
399impl TypedGlobalVariable<'_> {
400    fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
401        let var = &self.module.global_variables[self.handle];
402        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
403
404        let storage_access = match var.space {
405            crate::AddressSpace::Storage { access } => access,
406            _ => match self.module.types[var.ty].inner {
407                crate::TypeInner::Image {
408                    class: crate::ImageClass::Storage { access, .. },
409                    ..
410                } => access,
411                crate::TypeInner::BindingArray { base, .. } => {
412                    match self.module.types[base].inner {
413                        crate::TypeInner::Image {
414                            class: crate::ImageClass::Storage { access, .. },
415                            ..
416                        } => access,
417                        _ => crate::StorageAccess::default(),
418                    }
419                }
420                _ => crate::StorageAccess::default(),
421            },
422        };
423        let ty_name = TypeContext {
424            handle: var.ty,
425            gctx: self.module.to_ctx(),
426            names: self.names,
427            access: storage_access,
428            first_time: false,
429        };
430
431        let (space, access, reference) = match var.space.to_msl_name() {
432            Some(space) if self.reference => {
433                let access = if var.space.needs_access_qualifier()
434                    && !self.usage.intersects(valid::GlobalUse::WRITE)
435                {
436                    "const"
437                } else {
438                    ""
439                };
440                (space, access, "&")
441            }
442            _ => ("", "", ""),
443        };
444
445        Ok(write!(
446            out,
447            "{}{}{}{}{}{} {}",
448            space,
449            if space.is_empty() { "" } else { " " },
450            ty_name,
451            if access.is_empty() { "" } else { " " },
452            access,
453            reference,
454            name,
455        )?)
456    }
457}
458
/// Key describing one helper ("wrapped") function.
///
/// Values of this type are stored in the `wrapped_functions` hash set of
/// `Writer`, which is why the type derives `Eq` and `Hash`.
///
/// NOTE(review): presumably the set records which helpers have already been
/// written, so each is emitted once per distinct key — the code that uses the
/// set is outside this view; confirm there.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    // Keyed by the operator and the (optional vector size, scalar) of its operand.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    // Keyed by the operator and the (optional vector size, scalar) of each operand.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    // Keyed by the math function and the (optional vector size, scalar) of its argument.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    // Keyed by source scalar, destination scalar, and the optional vector size.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    // The image helpers below are keyed by image class only.
    ImageLoad {
        class: crate::ImageClass,
    },
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    ImageQuerySize {
        class: crate::ImageClass,
    },
    // The cooperative-matrix helpers are keyed by an address-space name plus
    // the matrix dimensions and scalar.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}
503
/// State of the MSL writer; `W` is the sink that receives the generated text.
pub struct Writer<W> {
    /// Output sink.
    out: W,
    /// Names keyed by `NameKey`; type and global-variable names are looked up
    /// here when formatting.
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    namer: proc::Namer,
    /// Set of `WrappedFunction` keys.
    // NOTE(review): presumably the helpers already emitted — confirm at the use sites.
    wrapped_functions: FastHashSet<WrappedFunction>,
    // The two pointer sets below exist only in test builds.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
520
521impl crate::Scalar {
522    pub(super) fn to_msl_name(self) -> &'static str {
523        use crate::ScalarKind as Sk;
524        match self {
525            Self {
526                kind: Sk::Float,
527                width: 4,
528            } => "float",
529            Self {
530                kind: Sk::Float,
531                width: 2,
532            } => "half",
533            Self {
534                kind: Sk::Sint,
535                width: 4,
536            } => "int",
537            Self {
538                kind: Sk::Uint,
539                width: 4,
540            } => "uint",
541            Self {
542                kind: Sk::Sint,
543                width: 8,
544            } => "long",
545            Self {
546                kind: Sk::Uint,
547                width: 8,
548            } => "ulong",
549            Self {
550                kind: Sk::Bool,
551                width: _,
552            } => "bool",
553            Self {
554                kind: Sk::AbstractInt | Sk::AbstractFloat,
555                width: _,
556            } => unreachable!("Found Abstract scalar kind"),
557            _ => unreachable!("Unsupported scalar kind: {:?}", self),
558        }
559    }
560}
561
/// Returns `","` when a separator is needed and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
569
570fn should_pack_struct_member(
571    members: &[crate::StructMember],
572    span: u32,
573    index: usize,
574    module: &crate::Module,
575) -> Option<crate::Scalar> {
576    let member = &members[index];
577
578    let ty_inner = &module.types[member.ty].inner;
579    let last_offset = member.offset + ty_inner.size(module.to_ctx());
580    let next_offset = match members.get(index + 1) {
581        Some(next) => next.offset,
582        None => span,
583    };
584    let is_tight = next_offset == last_offset;
585
586    match *ty_inner {
587        crate::TypeInner::Vector {
588            size: crate::VectorSize::Tri,
589            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
590        } if is_tight => Some(scalar),
591        _ => None,
592    }
593}
594
595fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
596    match arena[ty].inner {
597        crate::TypeInner::Struct { ref members, .. } => {
598            if let Some(member) = members.last() {
599                if let crate::TypeInner::Array {
600                    size: crate::ArraySize::Dynamic,
601                    ..
602                } = arena[member.ty].inner
603                {
604                    return true;
605                }
606            }
607            false
608        }
609        crate::TypeInner::Array {
610            size: crate::ArraySize::Dynamic,
611            ..
612        } => true,
613        _ => false,
614    }
615}
616
617impl crate::AddressSpace {
618    /// Returns true if global variables in this address space are
619    /// passed in function arguments. These arguments need to be
620    /// passed through any functions called from the entry point.
621    const fn needs_pass_through(&self) -> bool {
622        match *self {
623            Self::Uniform
624            | Self::Storage { .. }
625            | Self::Private
626            | Self::WorkGroup
627            | Self::Immediate
628            | Self::Handle
629            | Self::TaskPayload => true,
630            Self::Function => false,
631            Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
632        }
633    }
634
635    /// Returns true if the address space may need a "const" qualifier.
636    const fn needs_access_qualifier(&self) -> bool {
637        match *self {
638            //Note: we are ignoring the storage access here, and instead
639            // rely on the actual use of a global by functions. This means we
640            // may end up with "const" even if the binding is read-write,
641            // and that should be OK.
642            Self::Storage { .. } => true,
643            Self::TaskPayload | Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
644            // These should always be read-write.
645            Self::Private | Self::WorkGroup => false,
646            // These translate to `constant` address space, no need for qualifiers.
647            Self::Uniform | Self::Immediate => false,
648            // Not applicable.
649            Self::Handle | Self::Function => false,
650        }
651    }
652
653    const fn to_msl_name(self) -> Option<&'static str> {
654        match self {
655            Self::Handle => None,
656            Self::Uniform | Self::Immediate => Some("constant"),
657            Self::Storage { .. } => Some("device"),
658            // note for `RayPayload`, this probably needs to be emulated as a
659            // private variable, as metal has essentially an inout input
660            // for where it is passed.
661            Self::Private | Self::Function | Self::RayPayload => Some("thread"),
662            Self::WorkGroup => Some("threadgroup"),
663            Self::TaskPayload => Some("object_data"),
664            Self::IncomingRayPayload => Some("ray_data"),
665        }
666    }
667}
668
669impl crate::Type {
670    // Returns `true` if we need to emit an alias for this type.
671    const fn needs_alias(&self) -> bool {
672        use crate::TypeInner as Ti;
673
674        match self.inner {
675            // value types are concise enough, we only alias them if they are named
676            Ti::Scalar(_)
677            | Ti::Vector { .. }
678            | Ti::Matrix { .. }
679            | Ti::CooperativeMatrix { .. }
680            | Ti::Atomic(_)
681            | Ti::Pointer { .. }
682            | Ti::ValuePointer { .. } => self.name.is_some(),
683            // composite types are better to be aliased, regardless of the name
684            Ti::Struct { .. } | Ti::Array { .. } => true,
685            // handle types may be different, depending on the global var access, so we always inline them
686            Ti::Image { .. }
687            | Ti::Sampler { .. }
688            | Ti::AccelerationStructure { .. }
689            | Ti::RayQuery { .. }
690            | Ti::BindingArray { .. } => false,
691        }
692    }
693}
694
/// Identifies the function being written: either a regular function, by its
/// handle, or an entry point, by its index.
///
/// `NameKeyExt` uses this to choose between the function and entry-point
/// flavors of a `NameKey`.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A regular function of the module.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by index.
    EntryPoint(proc::EntryPointIndex),
}
700
701trait NameKeyExt {
702    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
703        match origin {
704            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
705            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
706        }
707    }
708
709    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
710    /// policy when it needs to produce a pointer-typed result for an OOB access. These
711    /// are unique per accessed type, so the second argument is a type handle. See docs
712    /// for [`crate::back::msl`].
713    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
714        match origin {
715            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
716            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
717        }
718    }
719}
720
721impl NameKeyExt for NameKey {}
722
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use the level-of-detail expression as is.
    Direct(Handle<crate::Expression>),
    /// Use the saved, clamped level of detail.
    // NOTE(review): presumably the handle here is that of the `ImageLoad`
    // expression naming the temporary (see `ClampedLod`) — confirm at use sites.
    Restricted(Handle<crate::Expression>),
}
737
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Coordinate expression; always present.
    coordinate: Handle<crate::Expression>,
    /// Array index expression, if any.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index expression, if any.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if any; see [`LevelOfDetail`].
    level: Option<LevelOfDetail>,
}
753
/// Everything the writer needs to know while writing the expressions of one
/// function.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; indexed by expression handle to
    /// resolve expression types.
    info: &'a valid::FunctionInfo,
    /// The module that owns `function`.
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    // NOTE(review): presumably the (major, minor) MSL version — confirm.
    lang_version: (u8, u8),
    /// Bounds check policies consulted when accesses are written.
    policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    force_loop_bounding: bool,
}
771
772impl<'a> ExpressionContext<'a> {
773    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
774        self.info[handle].ty.inner_with(&self.module.types)
775    }
776
777    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
778    ///
779    /// Only mipmapped images need to specify a level of detail. Since 1D
780    /// textures cannot have mipmaps, MSL requires that the level argument to
781    /// texture1d queries and accesses must be a constexpr 0. It's easiest
782    /// just to omit the level entirely for 1D textures.
783    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
784        let image_ty = self.resolve_type(image);
785        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
786            class.is_mipmapped() && dim != crate::ImageDimension::D1
787        } else {
788            false
789        }
790    }
791
792    fn choose_bounds_check_policy(
793        &self,
794        pointer: Handle<crate::Expression>,
795    ) -> index::BoundsCheckPolicy {
796        self.policies
797            .choose_policy(pointer, &self.module.types, self.info)
798    }
799
800    /// See docs for [`proc::index::access_needs_check`].
801    fn access_needs_check(
802        &self,
803        base: Handle<crate::Expression>,
804        index: index::GuardedIndex,
805    ) -> Option<index::IndexableLength> {
806        index::access_needs_check(
807            base,
808            index,
809            self.module,
810            &self.function.expressions,
811            self.info,
812        )
813    }
814
815    /// See docs for [`proc::index::bounds_check_iter`].
816    fn bounds_check_iter(
817        &self,
818        chain: Handle<crate::Expression>,
819    ) -> impl Iterator<Item = BoundsCheck> + '_ {
820        index::bounds_check_iter(chain, self.module, self.function, self.info)
821    }
822
823    /// See docs for [`proc::index::oob_local_types`].
824    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
825        index::oob_local_types(self.module, self.function, self.info, self.policies)
826    }
827
828    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
829        match self.function.expressions[expr_handle] {
830            crate::Expression::AccessIndex { base, index } => {
831                let ty = match *self.resolve_type(base) {
832                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
833                    ref ty => ty,
834                };
835                match *ty {
836                    crate::TypeInner::Struct {
837                        ref members, span, ..
838                    } => should_pack_struct_member(members, span, index as usize, self.module),
839                    _ => None,
840                }
841            }
842            _ => None,
843        }
844    }
845}
846
/// Context for writing statements.
///
/// Wraps the [`ExpressionContext`] that statement writers hand on to the
/// expression-writing routines (as `&context.expression`).
struct StatementContext<'a> {
    /// Context used for every expression written as part of a statement.
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct type used for the
    /// function's result, when there is one — confirm against the code that
    /// writes `Return` statements.
    result_struct: Option<&'a str>,
}
851
852impl<W: Write> Writer<W> {
853    /// Creates a new `Writer` instance.
854    pub fn new(out: W) -> Self {
855        Writer {
856            out,
857            names: FastHashMap::default(),
858            named_expressions: Default::default(),
859            need_bake_expressions: Default::default(),
860            namer: proc::Namer::default(),
861            wrapped_functions: FastHashSet::default(),
862            #[cfg(test)]
863            put_expression_stack_pointers: Default::default(),
864            #[cfg(test)]
865            put_block_stack_pointers: Default::default(),
866            struct_member_pads: FastHashSet::default(),
867        }
868    }
869
870    /// Finishes writing and returns the output.
871    // See https://github.com/rust-lang/rust-clippy/issues/4979.
872    pub fn finish(self) -> W {
873        self.out
874    }
875
876    /// Generates statements to be inserted immediately before and at the very
877    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
878    /// The 0th item of the returned tuple should be inserted immediately prior
879    /// to the loop and the 1st item should be inserted at the very start of
880    /// the loop body.
881    ///
882    /// # What is this trying to solve?
883    ///
884    /// In Metal Shading Language, an infinite loop has undefined behavior.
885    /// (This rule is inherited from C++14.) This means that, if the MSL
886    /// compiler determines that a given loop will never exit, it may assume
887    /// that it is never reached. It may thus assume that any conditions
888    /// sufficient to cause the loop to be reached must be false. Like many
889    /// optimizing compilers, MSL uses this kind of analysis to establish limits
890    /// on the range of values variables involved in those conditions might
891    /// hold.
892    ///
893    /// For example, suppose the MSL compiler sees the code:
894    ///
895    /// ```ignore
896    /// if (i >= 10) {
897    ///     while (true) { }
898    /// }
899    /// ```
900    ///
901    /// It will recognize that the `while` loop will never terminate, conclude
902    /// that it must be unreachable, and thus infer that, if this code is
903    /// reached, then `i < 10` at that point.
904    ///
905    /// Now suppose that, at some point where `i` has the same value as above,
906    /// the compiler sees the code:
907    ///
908    /// ```ignore
909    /// if (i < 10) {
910    ///     a[i] = 1;
911    /// }
912    /// ```
913    ///
914    /// Because the compiler is confident that `i < 10`, it will make the
915    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
916    ///
917    /// ```ignore
918    /// a[i] = 1;
919    /// ```
920    ///
921    /// If that `if` condition was injected by Naga to implement a bounds check,
922    /// the MSL compiler's optimizations could allow out-of-bounds array
923    /// accesses to occur.
924    ///
925    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
926    /// that a loop is infinite, so an attacker could craft a Naga module
927    /// containing an infinite loop protected by conditions that cause the Metal
928    /// compiler to remove bounds checks that Naga injected elsewhere in the
929    /// function.
930    ///
931    /// This rewrite could occur even if the conditional assignment appears
932    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
933    /// reached. This would allow the attacker to save the results of
934    /// unauthorized reads somewhere accessible before entering the infinite
935    /// loop. But even worse, the MSL compiler has been observed to simply
936    /// delete the infinite loop entirely, so that even code dominated by the
937    /// loop becomes reachable. This would make the attack even more flexible,
938    /// since shaders that would appear to never terminate would actually exit
939    /// nicely, after having stolen data from elsewhere in the GPU address
940    /// space.
941    ///
942    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
943    /// generates is infinite. One approach would be to add inline assembly to
944    /// each loop that is annotated as potentially branching out of the loop,
945    /// but which in fact generates no instructions. Unfortunately, inline
946    /// assembly is not handled correctly by some Metal device drivers.
947    ///
948    /// A previously used approach was to add the following code to the bottom
949    /// of every loop:
950    ///
951    /// ```ignore
952    /// if (volatile bool unpredictable = false; unpredictable)
953    ///     break;
954    /// ```
955    ///
956    /// Although the `if` condition will always be false in any real execution,
957    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
958    /// it must assume that the `break` might be reached, and hence that the
959    /// loop is not unbounded. This prevents the range analysis impact described
960    /// above. Unfortunately this prevented the compiler from making important,
961    /// and safe, optimizations such as loop unrolling and was observed to
962    /// significantly hurt performance.
963    ///
    /// Our current approach declares a counter before every loop and
    /// decrements it every iteration, breaking after 2^64 iterations:
    ///
    /// ```ignore
    /// uint2 loop_bound = uint2(4294967295u);
    /// while (true) {
    ///   if (metal::all(loop_bound == uint2(0u))) { break; }
    ///   loop_bound -= uint2(loop_bound.y == 0u, 1u);
    /// }
    /// ```
974    ///
975    /// This convinces the compiler that the loop is finite and therefore may
976    /// execute, whilst at the same time allowing optimizations such as loop
977    /// unrolling. Furthermore the 64-bit counter is large enough it seems
978    /// implausible that it would affect the execution of any shader.
979    ///
980    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
981    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
982    fn gen_force_bounded_loop_statements(
983        &mut self,
984        level: back::Level,
985        context: &StatementContext,
986    ) -> Option<(String, String)> {
987        if !context.expression.force_loop_bounding {
988            return None;
989        }
990
991        let loop_bound_name = self.namer.call("loop_bound");
992        // Count down from u32::MAX rather than up from 0 to avoid hang on
993        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
994        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
995        let level = level.next();
996        let break_and_inc = format!(
997            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
998{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
999        );
1000
1001        Some((decl, break_and_inc))
1002    }
1003
1004    fn put_call_parameters(
1005        &mut self,
1006        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1007        context: &ExpressionContext,
1008    ) -> BackendResult {
1009        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1010            writer.put_expression(expr, context, true)
1011        })
1012    }
1013
1014    fn put_call_parameters_impl<C, E>(
1015        &mut self,
1016        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1017        ctx: &C,
1018        put_expression: E,
1019    ) -> BackendResult
1020    where
1021        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1022    {
1023        write!(self.out, "(")?;
1024        for (i, handle) in parameters.enumerate() {
1025            if i != 0 {
1026                write!(self.out, ", ")?;
1027            }
1028            put_expression(self, ctx, handle)?;
1029        }
1030        write!(self.out, ")")?;
1031        Ok(())
1032    }
1033
1034    /// Writes the local variables of the given function, as well as any extra
1035    /// out-of-bounds locals that are needed.
1036    ///
1037    /// The names of the OOB locals are also added to `self.names` at the same
1038    /// time.
1039    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1040        let oob_local_types = context.oob_local_types();
1041        for &ty in oob_local_types.iter() {
1042            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1043            self.names.insert(name_key, self.namer.call("oob"));
1044        }
1045
1046        for (name_key, ty, init) in context
1047            .function
1048            .local_variables
1049            .iter()
1050            .map(|(local_handle, local)| {
1051                let name_key = NameKey::local(context.origin, local_handle);
1052                (name_key, local.ty, local.init)
1053            })
1054            .chain(oob_local_types.iter().map(|&ty| {
1055                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1056                (name_key, ty, None)
1057            }))
1058        {
1059            let ty_name = TypeContext {
1060                handle: ty,
1061                gctx: context.module.to_ctx(),
1062                names: &self.names,
1063                access: crate::StorageAccess::empty(),
1064                first_time: false,
1065            };
1066            write!(
1067                self.out,
1068                "{}{} {}",
1069                back::INDENT,
1070                ty_name,
1071                self.names[&name_key]
1072            )?;
1073            match init {
1074                Some(value) => {
1075                    write!(self.out, " = ")?;
1076                    self.put_expression(value, context, true)?;
1077                }
1078                None => {
1079                    write!(self.out, " = {{}}")?;
1080                }
1081            };
1082            writeln!(self.out, ";")?;
1083        }
1084        Ok(())
1085    }
1086
1087    fn put_level_of_detail(
1088        &mut self,
1089        level: LevelOfDetail,
1090        context: &ExpressionContext,
1091    ) -> BackendResult {
1092        match level {
1093            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1094            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1095        }
1096        Ok(())
1097    }
1098
1099    fn put_image_query(
1100        &mut self,
1101        image: Handle<crate::Expression>,
1102        query: &str,
1103        level: Option<LevelOfDetail>,
1104        context: &ExpressionContext,
1105    ) -> BackendResult {
1106        self.put_expression(image, context, false)?;
1107        write!(self.out, ".get_{query}(")?;
1108        if let Some(level) = level {
1109            self.put_level_of_detail(level, context)?;
1110        }
1111        write!(self.out, ")")?;
1112        Ok(())
1113    }
1114
1115    fn put_image_size_query(
1116        &mut self,
1117        image: Handle<crate::Expression>,
1118        level: Option<LevelOfDetail>,
1119        kind: crate::ScalarKind,
1120        context: &ExpressionContext,
1121    ) -> BackendResult {
1122        if let crate::TypeInner::Image {
1123            class: crate::ImageClass::External,
1124            ..
1125        } = *context.resolve_type(image)
1126        {
1127            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1128            self.put_expression(image, context, true)?;
1129            write!(self.out, ")")?;
1130            return Ok(());
1131        }
1132
1133        //Note: MSL only has separate width/height/depth queries,
1134        // so compose the result of them.
1135        let dim = match *context.resolve_type(image) {
1136            crate::TypeInner::Image { dim, .. } => dim,
1137            ref other => unreachable!("Unexpected type {:?}", other),
1138        };
1139        let scalar = crate::Scalar { kind, width: 4 };
1140        let coordinate_type = scalar.to_msl_name();
1141        match dim {
1142            crate::ImageDimension::D1 => {
1143                // Since 1D textures never have mipmaps, MSL requires that the
1144                // `level` argument be a constexpr 0. It's simplest for us just
1145                // to pass `None` and omit the level entirely.
1146                if kind == crate::ScalarKind::Uint {
1147                    // No need to construct a vector. No cast needed.
1148                    self.put_image_query(image, "width", None, context)?;
1149                } else {
1150                    // There's no definition for `int` in the `metal` namespace.
1151                    write!(self.out, "int(")?;
1152                    self.put_image_query(image, "width", None, context)?;
1153                    write!(self.out, ")")?;
1154                }
1155            }
1156            crate::ImageDimension::D2 => {
1157                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1158                self.put_image_query(image, "width", level, context)?;
1159                write!(self.out, ", ")?;
1160                self.put_image_query(image, "height", level, context)?;
1161                write!(self.out, ")")?;
1162            }
1163            crate::ImageDimension::D3 => {
1164                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1165                self.put_image_query(image, "width", level, context)?;
1166                write!(self.out, ", ")?;
1167                self.put_image_query(image, "height", level, context)?;
1168                write!(self.out, ", ")?;
1169                self.put_image_query(image, "depth", level, context)?;
1170                write!(self.out, ")")?;
1171            }
1172            crate::ImageDimension::Cube => {
1173                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1174                self.put_image_query(image, "width", level, context)?;
1175                write!(self.out, ")")?;
1176            }
1177        }
1178        Ok(())
1179    }
1180
1181    fn put_cast_to_uint_scalar_or_vector(
1182        &mut self,
1183        expr: Handle<crate::Expression>,
1184        context: &ExpressionContext,
1185    ) -> BackendResult {
1186        // coordinates in IR are int, but Metal expects uint
1187        match *context.resolve_type(expr) {
1188            crate::TypeInner::Scalar(_) => {
1189                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1190            }
1191            crate::TypeInner::Vector { size, .. } => {
1192                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1193            }
1194            _ => {
1195                return Err(Error::GenericValidation(
1196                    "Invalid type for image coordinate".into(),
1197                ))
1198            }
1199        };
1200
1201        write!(self.out, "(")?;
1202        self.put_expression(expr, context, true)?;
1203        write!(self.out, ")")?;
1204        Ok(())
1205    }
1206
1207    fn put_image_sample_level(
1208        &mut self,
1209        image: Handle<crate::Expression>,
1210        level: crate::SampleLevel,
1211        context: &ExpressionContext,
1212    ) -> BackendResult {
1213        let has_levels = context.image_needs_lod(image);
1214        match level {
1215            crate::SampleLevel::Auto => {}
1216            crate::SampleLevel::Zero => {
1217                //TODO: do we support Zero on `Sampled` image classes?
1218            }
1219            _ if !has_levels => {
1220                log::warn!("1D image can't be sampled with level {level:?}");
1221            }
1222            crate::SampleLevel::Exact(h) => {
1223                write!(self.out, ", {NAMESPACE}::level(")?;
1224                self.put_expression(h, context, true)?;
1225                write!(self.out, ")")?;
1226            }
1227            crate::SampleLevel::Bias(h) => {
1228                write!(self.out, ", {NAMESPACE}::bias(")?;
1229                self.put_expression(h, context, true)?;
1230                write!(self.out, ")")?;
1231            }
1232            crate::SampleLevel::Gradient { x, y } => {
1233                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1234                self.put_expression(x, context, true)?;
1235                write!(self.out, ", ")?;
1236                self.put_expression(y, context, true)?;
1237                write!(self.out, ")")?;
1238            }
1239        }
1240        Ok(())
1241    }
1242
1243    fn put_image_coordinate_limits(
1244        &mut self,
1245        image: Handle<crate::Expression>,
1246        level: Option<LevelOfDetail>,
1247        context: &ExpressionContext,
1248    ) -> BackendResult {
1249        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1250        write!(self.out, " - 1")?;
1251        Ok(())
1252    }
1253
1254    /// General function for writing restricted image indexes.
1255    ///
1256    /// This is used to produce restricted mip levels, array indices, and sample
1257    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1258    /// [`Restrict`] bounds check policy.
1259    ///
1260    /// This function writes an expression of the form:
1261    ///
1262    /// ```ignore
1263    ///
1264    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1265    ///
1266    /// ```
1267    ///
1268    /// [`ImageLoad`]: crate::Expression::ImageLoad
1269    /// [`ImageStore`]: crate::Statement::ImageStore
1270    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1271    fn put_restricted_scalar_image_index(
1272        &mut self,
1273        image: Handle<crate::Expression>,
1274        index: Handle<crate::Expression>,
1275        limit_method: &str,
1276        context: &ExpressionContext,
1277    ) -> BackendResult {
1278        write!(self.out, "{NAMESPACE}::min(uint(")?;
1279        self.put_expression(index, context, true)?;
1280        write!(self.out, "), ")?;
1281        self.put_expression(image, context, false)?;
1282        write!(self.out, ".{limit_method}() - 1)")?;
1283        Ok(())
1284    }
1285
1286    fn put_restricted_texel_address(
1287        &mut self,
1288        image: Handle<crate::Expression>,
1289        address: &TexelAddress,
1290        context: &ExpressionContext,
1291    ) -> BackendResult {
1292        // Write the coordinate.
1293        write!(self.out, "{NAMESPACE}::min(")?;
1294        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1295        write!(self.out, ", ")?;
1296        self.put_image_coordinate_limits(image, address.level, context)?;
1297        write!(self.out, ")")?;
1298
1299        // Write the array index, if present.
1300        if let Some(array_index) = address.array_index {
1301            write!(self.out, ", ")?;
1302            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1303        }
1304
1305        // Write the sample index, if present.
1306        if let Some(sample) = address.sample {
1307            write!(self.out, ", ")?;
1308            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1309        }
1310
1311        // The level of detail should be clamped and cached by
1312        // `put_cache_restricted_level`, so we don't need to clamp it here.
1313        if let Some(level) = address.level {
1314            write!(self.out, ", ")?;
1315            self.put_level_of_detail(level, context)?;
1316        }
1317
1318        Ok(())
1319    }
1320
1321    /// Write an expression that is true if the given image access is in bounds.
1322    fn put_image_access_bounds_check(
1323        &mut self,
1324        image: Handle<crate::Expression>,
1325        address: &TexelAddress,
1326        context: &ExpressionContext,
1327    ) -> BackendResult {
1328        let mut conjunction = "";
1329
1330        // First, check the level of detail. Only if that is in bounds can we
1331        // use it to find the appropriate bounds for the coordinates.
1332        let level = if let Some(level) = address.level {
1333            write!(self.out, "uint(")?;
1334            self.put_level_of_detail(level, context)?;
1335            write!(self.out, ") < ")?;
1336            self.put_expression(image, context, true)?;
1337            write!(self.out, ".get_num_mip_levels()")?;
1338            conjunction = " && ";
1339            Some(level)
1340        } else {
1341            None
1342        };
1343
1344        // Check sample index, if present.
1345        if let Some(sample) = address.sample {
1346            write!(self.out, "uint(")?;
1347            self.put_expression(sample, context, true)?;
1348            write!(self.out, ") < ")?;
1349            self.put_expression(image, context, true)?;
1350            write!(self.out, ".get_num_samples()")?;
1351            conjunction = " && ";
1352        }
1353
1354        // Check array index, if present.
1355        if let Some(array_index) = address.array_index {
1356            write!(self.out, "{conjunction}uint(")?;
1357            self.put_expression(array_index, context, true)?;
1358            write!(self.out, ") < ")?;
1359            self.put_expression(image, context, true)?;
1360            write!(self.out, ".get_array_size()")?;
1361            conjunction = " && ";
1362        }
1363
1364        // Finally, check if the coordinates are within bounds.
1365        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1366            crate::TypeInner::Vector { .. } => true,
1367            _ => false,
1368        };
1369        write!(self.out, "{conjunction}")?;
1370        if coord_is_vector {
1371            write!(self.out, "{NAMESPACE}::all(")?;
1372        }
1373        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1374        write!(self.out, " < ")?;
1375        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1376        if coord_is_vector {
1377            write!(self.out, ")")?;
1378        }
1379
1380        Ok(())
1381    }
1382
1383    fn put_image_load(
1384        &mut self,
1385        load: Handle<crate::Expression>,
1386        image: Handle<crate::Expression>,
1387        mut address: TexelAddress,
1388        context: &ExpressionContext,
1389    ) -> BackendResult {
1390        if let crate::TypeInner::Image {
1391            class: crate::ImageClass::External,
1392            ..
1393        } = *context.resolve_type(image)
1394        {
1395            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1396            self.put_expression(image, context, true)?;
1397            write!(self.out, ", ")?;
1398            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1399            write!(self.out, ")")?;
1400            return Ok(());
1401        }
1402
1403        match context.policies.image_load {
1404            proc::BoundsCheckPolicy::Restrict => {
1405                // Use the cached restricted level of detail, if any. Omit the
1406                // level altogether for 1D textures.
1407                if address.level.is_some() {
1408                    address.level = if context.image_needs_lod(image) {
1409                        Some(LevelOfDetail::Restricted(load))
1410                    } else {
1411                        None
1412                    }
1413                }
1414
1415                self.put_expression(image, context, false)?;
1416                write!(self.out, ".read(")?;
1417                self.put_restricted_texel_address(image, &address, context)?;
1418                write!(self.out, ")")?;
1419            }
1420            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1421                write!(self.out, "(")?;
1422                self.put_image_access_bounds_check(image, &address, context)?;
1423                write!(self.out, " ? ")?;
1424                self.put_unchecked_image_load(image, &address, context)?;
1425                write!(self.out, ": DefaultConstructible())")?;
1426            }
1427            proc::BoundsCheckPolicy::Unchecked => {
1428                self.put_unchecked_image_load(image, &address, context)?;
1429            }
1430        }
1431
1432        Ok(())
1433    }
1434
1435    fn put_unchecked_image_load(
1436        &mut self,
1437        image: Handle<crate::Expression>,
1438        address: &TexelAddress,
1439        context: &ExpressionContext,
1440    ) -> BackendResult {
1441        self.put_expression(image, context, false)?;
1442        write!(self.out, ".read(")?;
1443        // coordinates in IR are int, but Metal expects uint
1444        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1445        if let Some(expr) = address.array_index {
1446            write!(self.out, ", ")?;
1447            self.put_expression(expr, context, true)?;
1448        }
1449        if let Some(sample) = address.sample {
1450            write!(self.out, ", ")?;
1451            self.put_expression(sample, context, true)?;
1452        }
1453        if let Some(level) = address.level {
1454            if context.image_needs_lod(image) {
1455                write!(self.out, ", ")?;
1456                self.put_level_of_detail(level, context)?;
1457            }
1458        }
1459        write!(self.out, ")")?;
1460
1461        Ok(())
1462    }
1463
1464    fn put_image_atomic(
1465        &mut self,
1466        level: back::Level,
1467        image: Handle<crate::Expression>,
1468        address: &TexelAddress,
1469        fun: crate::AtomicFunction,
1470        value: Handle<crate::Expression>,
1471        context: &StatementContext,
1472    ) -> BackendResult {
1473        write!(self.out, "{level}")?;
1474        self.put_expression(image, &context.expression, false)?;
1475        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1476            fun.to_msl_64_bit()?
1477        } else {
1478            fun.to_msl()
1479        };
1480        write!(self.out, ".atomic_{op}(")?;
1481        // coordinates in IR are int, but Metal expects uint
1482        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1483        write!(self.out, ", ")?;
1484        self.put_expression(value, &context.expression, true)?;
1485        writeln!(self.out, ");")?;
1486
1487        Ok(())
1488    }
1489
1490    fn put_image_store(
1491        &mut self,
1492        level: back::Level,
1493        image: Handle<crate::Expression>,
1494        address: &TexelAddress,
1495        value: Handle<crate::Expression>,
1496        context: &StatementContext,
1497    ) -> BackendResult {
1498        write!(self.out, "{level}")?;
1499        self.put_expression(image, &context.expression, false)?;
1500        write!(self.out, ".write(")?;
1501        self.put_expression(value, &context.expression, true)?;
1502        write!(self.out, ", ")?;
1503        // coordinates in IR are int, but Metal expects uint
1504        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1505        if let Some(expr) = address.array_index {
1506            write!(self.out, ", ")?;
1507            self.put_expression(expr, &context.expression, true)?;
1508        }
1509        writeln!(self.out, ");")?;
1510
1511        Ok(())
1512    }
1513
1514    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1515    ///
1516    /// The 'maximum valid index' is simply one less than the array's length.
1517    ///
1518    /// This emits an expression of the form `a / b`, so the caller must
1519    /// parenthesize its output if it will be applying operators of higher
1520    /// precedence.
1521    ///
1522    /// `handle` must be the handle of a global variable whose final member is a
1523    /// dynamically sized array.
1524    fn put_dynamic_array_max_index(
1525        &mut self,
1526        handle: Handle<crate::GlobalVariable>,
1527        context: &ExpressionContext,
1528    ) -> BackendResult {
1529        let global = &context.module.global_variables[handle];
1530        let (offset, array_ty) = match context.module.types[global.ty].inner {
1531            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1532                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1533                None => return Err(Error::GenericValidation("Struct has no members".into())),
1534            },
1535            crate::TypeInner::Array {
1536                size: crate::ArraySize::Dynamic,
1537                ..
1538            } => (0, global.ty),
1539            ref ty => {
1540                return Err(Error::GenericValidation(format!(
1541                    "Expected type with dynamic array, got {ty:?}"
1542                )))
1543            }
1544        };
1545
1546        let (size, stride) = match context.module.types[array_ty].inner {
1547            crate::TypeInner::Array { base, stride, .. } => (
1548                context.module.types[base]
1549                    .inner
1550                    .size(context.module.to_ctx()),
1551                stride,
1552            ),
1553            ref ty => {
1554                return Err(Error::GenericValidation(format!(
1555                    "Expected array type, got {ty:?}"
1556                )))
1557            }
1558        };
1559
1560        // When the stride length is larger than the size, the final element's stride of
1561        // bytes would have padding following the value. But the buffer size in
1562        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1563        // enough to hold the actual values' bytes.
1564        //
1565        // So subtract off the size to get a byte size that falls at the start or within
1566        // the final element. Then divide by the stride size, to get one less than the
1567        // length, and then add one. This works even if the buffer size does include the
1568        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1569        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1570        // rules, together with draw-time validation when `minBindingSize` is zero,
1571        // prevent that.
1572        write!(
1573            self.out,
1574            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1575            member = ArraySizeMember(handle),
1576            offset = offset,
1577            size = size,
1578            stride = stride,
1579        )?;
1580        Ok(())
1581    }
1582
1583    /// Emit code for the arithmetic expression of the dot product.
1584    ///
1585    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1586    /// an index. It writes out the expression for the vector component at that index.
1587    fn put_dot_product<T: Copy>(
1588        &mut self,
1589        arg: T,
1590        arg1: T,
1591        size: usize,
1592        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1593    ) -> BackendResult {
1594        // Write parentheses around the dot product expression to prevent operators
1595        // with different precedences from applying earlier.
1596        write!(self.out, "(")?;
1597
1598        // Cycle through all the components of the vector
1599        for index in 0..size {
1600            // Write the addition to the previous product
1601            // This will print an extra '+' at the beginning but that is fine in msl
1602            write!(self.out, " + ")?;
1603            extractor(self, arg, index)?;
1604            write!(self.out, " * ")?;
1605            extractor(self, arg1, index)?;
1606        }
1607
1608        write!(self.out, ")")?;
1609        Ok(())
1610    }
1611
1612    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1613    fn put_pack4x8(
1614        &mut self,
1615        arg: Handle<crate::Expression>,
1616        context: &ExpressionContext<'_>,
1617        was_signed: bool,
1618        clamp_bounds: Option<(&str, &str)>,
1619    ) -> Result<(), Error> {
1620        let write_arg = |this: &mut Self| -> BackendResult {
1621            if let Some((min, max)) = clamp_bounds {
1622                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1623                write!(this.out, "{NAMESPACE}::clamp(")?;
1624                this.put_expression(arg, context, true)?;
1625                write!(this.out, ", {min}, {max})")?;
1626            } else {
1627                this.put_expression(arg, context, true)?;
1628            }
1629            Ok(())
1630        };
1631
1632        if context.lang_version >= (2, 1) {
1633            let packed_type = if was_signed {
1634                "packed_char4"
1635            } else {
1636                "packed_uchar4"
1637            };
1638            // Metal uses little endian byte order, which matches what WGSL expects here.
1639            write!(self.out, "as_type<uint>({packed_type}(")?;
1640            write_arg(self)?;
1641            write!(self.out, "))")?;
1642        } else {
1643            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1644            if was_signed {
1645                write!(self.out, "uint(")?;
1646            }
1647            write!(self.out, "(")?;
1648            write_arg(self)?;
1649            write!(self.out, "[0] & 0xFF) | ((")?;
1650            write_arg(self)?;
1651            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1652            write_arg(self)?;
1653            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1654            write_arg(self)?;
1655            write!(self.out, "[3] & 0xFF) << 24)")?;
1656            if was_signed {
1657                write!(self.out, ")")?;
1658            }
1659        }
1660
1661        Ok(())
1662    }
1663
1664    /// Emit code for the isign expression.
1665    ///
1666    fn put_isign(
1667        &mut self,
1668        arg: Handle<crate::Expression>,
1669        context: &ExpressionContext,
1670    ) -> BackendResult {
1671        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1672        let scalar = context
1673            .resolve_type(arg)
1674            .scalar()
1675            .expect("put_isign should only be called for args which have an integer scalar type")
1676            .to_msl_name();
1677        match context.resolve_type(arg) {
1678            &crate::TypeInner::Vector { size, .. } => {
1679                let size = common::vector_size_str(size);
1680                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1681            }
1682            _ => {
1683                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1684            }
1685        }
1686        write!(self.out, ", (")?;
1687        self.put_expression(arg, context, true)?;
1688        write!(self.out, " > 0)), {scalar}(0), (")?;
1689        self.put_expression(arg, context, true)?;
1690        write!(self.out, " == 0))")?;
1691        Ok(())
1692    }
1693
1694    fn put_const_expression(
1695        &mut self,
1696        expr_handle: Handle<crate::Expression>,
1697        module: &crate::Module,
1698        mod_info: &valid::ModuleInfo,
1699        arena: &crate::Arena<crate::Expression>,
1700    ) -> BackendResult {
1701        self.put_possibly_const_expression(
1702            expr_handle,
1703            arena,
1704            module,
1705            mod_info,
1706            &(module, mod_info),
1707            |&(_, mod_info), expr| &mod_info[expr],
1708            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1709        )
1710    }
1711
1712    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1713        match literal {
1714            crate::Literal::F64(_) => {
1715                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1716            }
1717            crate::Literal::F16(value) => {
1718                if value.is_infinite() {
1719                    let sign = if value.is_sign_negative() { "-" } else { "" };
1720                    write!(self.out, "{sign}INFINITY")?;
1721                } else if value.is_nan() {
1722                    write!(self.out, "NAN")?;
1723                } else {
1724                    let suffix = if value.fract() == f16::from_f32(0.0) {
1725                        ".0h"
1726                    } else {
1727                        "h"
1728                    };
1729                    write!(self.out, "{value}{suffix}")?;
1730                }
1731            }
1732            crate::Literal::F32(value) => {
1733                if value.is_infinite() {
1734                    let sign = if value.is_sign_negative() { "-" } else { "" };
1735                    write!(self.out, "{sign}INFINITY")?;
1736                } else if value.is_nan() {
1737                    write!(self.out, "NAN")?;
1738                } else {
1739                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1740                    write!(self.out, "{value}{suffix}")?;
1741                }
1742            }
1743            crate::Literal::U32(value) => {
1744                write!(self.out, "{value}u")?;
1745            }
1746            crate::Literal::I32(value) => {
1747                // `-2147483648` is parsed as unary negation of positive 2147483648.
1748                // 2147483648 is too large for int32_t meaning the expression gets
1749                // promoted to a int64_t which is not our intention. Avoid this by instead
1750                // using `-2147483647 - 1`.
1751                if value == i32::MIN {
1752                    write!(self.out, "({} - 1)", value + 1)?;
1753                } else {
1754                    write!(self.out, "{value}")?;
1755                }
1756            }
1757            crate::Literal::U64(value) => {
1758                write!(self.out, "{value}uL")?;
1759            }
1760            crate::Literal::I64(value) => {
1761                // `-9223372036854775808` is parsed as unary negation of positive
1762                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1763                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1764                // value to `-9223372036854775808`. Which would then be negated, possibly
1765                // causing undefined behaviour. Avoid this by instead using
1766                // `-9223372036854775808L - 1L`.
1767                if value == i64::MIN {
1768                    write!(self.out, "({}L - 1L)", value + 1)?;
1769                } else {
1770                    write!(self.out, "{value}L")?;
1771                }
1772            }
1773            crate::Literal::Bool(value) => {
1774                write!(self.out, "{value}")?;
1775            }
1776            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1777                return Err(Error::GenericValidation(
1778                    "Unsupported abstract literal".into(),
1779                ));
1780            }
1781        }
1782        Ok(())
1783    }
1784
1785    #[allow(clippy::too_many_arguments)]
1786    fn put_possibly_const_expression<C, I, E>(
1787        &mut self,
1788        expr_handle: Handle<crate::Expression>,
1789        expressions: &crate::Arena<crate::Expression>,
1790        module: &crate::Module,
1791        mod_info: &valid::ModuleInfo,
1792        ctx: &C,
1793        get_expr_ty: I,
1794        put_expression: E,
1795    ) -> BackendResult
1796    where
1797        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1798        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1799    {
1800        match expressions[expr_handle] {
1801            crate::Expression::Literal(literal) => {
1802                self.put_literal(literal)?;
1803            }
1804            crate::Expression::Constant(handle) => {
1805                let constant = &module.constants[handle];
1806                if constant.name.is_some() {
1807                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1808                } else {
1809                    self.put_const_expression(
1810                        constant.init,
1811                        module,
1812                        mod_info,
1813                        &module.global_expressions,
1814                    )?;
1815                }
1816            }
1817            crate::Expression::ZeroValue(ty) => {
1818                let ty_name = TypeContext {
1819                    handle: ty,
1820                    gctx: module.to_ctx(),
1821                    names: &self.names,
1822                    access: crate::StorageAccess::empty(),
1823                    first_time: false,
1824                };
1825                write!(self.out, "{ty_name} {{}}")?;
1826            }
1827            crate::Expression::Compose { ty, ref components } => {
1828                let ty_name = TypeContext {
1829                    handle: ty,
1830                    gctx: module.to_ctx(),
1831                    names: &self.names,
1832                    access: crate::StorageAccess::empty(),
1833                    first_time: false,
1834                };
1835                write!(self.out, "{ty_name}")?;
1836                match module.types[ty].inner {
1837                    crate::TypeInner::Scalar(_)
1838                    | crate::TypeInner::Vector { .. }
1839                    | crate::TypeInner::Matrix { .. } => {
1840                        self.put_call_parameters_impl(
1841                            components.iter().copied(),
1842                            ctx,
1843                            put_expression,
1844                        )?;
1845                    }
1846                    crate::TypeInner::Array { .. } => {
1847                        // Naga Arrays are Metal arrays wrapped in structs, so
1848                        // we need two levels of braces.
1849                        write!(self.out, " {{{{")?;
1850                        for (index, &component) in components.iter().enumerate() {
1851                            if index != 0 {
1852                                write!(self.out, ", ")?;
1853                            }
1854                            put_expression(self, ctx, component)?;
1855                        }
1856                        write!(self.out, "}}}}")?;
1857                    }
1858                    crate::TypeInner::Struct { .. } => {
1859                        write!(self.out, " {{")?;
1860                        for (index, &component) in components.iter().enumerate() {
1861                            if index != 0 {
1862                                write!(self.out, ", ")?;
1863                            }
1864                            // insert padding initialization, if needed
1865                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1866                                write!(self.out, "{{}}, ")?;
1867                            }
1868                            put_expression(self, ctx, component)?;
1869                        }
1870                        write!(self.out, "}}")?;
1871                    }
1872                    _ => return Err(Error::UnsupportedCompose(ty)),
1873                }
1874            }
1875            crate::Expression::Splat { size, value } => {
1876                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1877                    crate::TypeInner::Scalar(scalar) => scalar,
1878                    ref ty => {
1879                        return Err(Error::GenericValidation(format!(
1880                            "Expected splat value type must be a scalar, got {ty:?}",
1881                        )))
1882                    }
1883                };
1884                put_numeric_type(&mut self.out, scalar, &[size])?;
1885                write!(self.out, "(")?;
1886                put_expression(self, ctx, value)?;
1887                write!(self.out, ")")?;
1888            }
1889            _ => {
1890                return Err(Error::Override);
1891            }
1892        }
1893
1894        Ok(())
1895    }
1896
1897    /// Emit code for the expression `expr_handle`.
1898    ///
1899    /// The `is_scoped` argument is true if the surrounding operators have the
1900    /// precedence of the comma operator, or lower. So, for example:
1901    ///
1902    /// - Pass `true` for `is_scoped` when writing function arguments, an
1903    ///   expression statement, an initializer expression, or anything already
1904    ///   wrapped in parenthesis.
1905    ///
1906    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1907    ///   almost anything else.
1908    fn put_expression(
1909        &mut self,
1910        expr_handle: Handle<crate::Expression>,
1911        context: &ExpressionContext,
1912        is_scoped: bool,
1913    ) -> BackendResult {
1914        // Add to the set in order to track the stack size.
1915        #[cfg(test)]
1916        self.put_expression_stack_pointers
1917            .insert(ptr::from_ref(&expr_handle).cast());
1918
1919        if let Some(name) = self.named_expressions.get(&expr_handle) {
1920            write!(self.out, "{name}")?;
1921            return Ok(());
1922        }
1923
1924        let expression = &context.function.expressions[expr_handle];
1925        match *expression {
1926            crate::Expression::Literal(_)
1927            | crate::Expression::Constant(_)
1928            | crate::Expression::ZeroValue(_)
1929            | crate::Expression::Compose { .. }
1930            | crate::Expression::Splat { .. } => {
1931                self.put_possibly_const_expression(
1932                    expr_handle,
1933                    &context.function.expressions,
1934                    context.module,
1935                    context.mod_info,
1936                    context,
1937                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1938                    |writer, context, expr| writer.put_expression(expr, context, true),
1939                )?;
1940            }
1941            crate::Expression::Override(_) => return Err(Error::Override),
1942            crate::Expression::Access { base, .. }
1943            | crate::Expression::AccessIndex { base, .. } => {
1944                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
1945                // Since `put_bounds_checks` and `put_access_chain` handle an entire
1946                // access chain at a time, recursing back through `put_expression` only
1947                // for index expressions and the base object, we will never see intermediate
1948                // `Access` or `AccessIndex` expressions here.
1949                let policy = context.choose_bounds_check_policy(base);
1950                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1951                    && self.put_bounds_checks(
1952                        expr_handle,
1953                        context,
1954                        back::Level(0),
1955                        if is_scoped { "" } else { "(" },
1956                    )?
1957                {
1958                    write!(self.out, " ? ")?;
1959                    self.put_access_chain(expr_handle, policy, context)?;
1960                    write!(self.out, " : ")?;
1961
1962                    if context.resolve_type(base).pointer_space().is_some() {
1963                        // We can't just use `DefaultConstructible` if this is a pointer.
1964                        // Instead, we create a dummy local variable to serve as pointer
1965                        // target if the access is out of bounds.
1966                        let result_ty = context.info[expr_handle]
1967                            .ty
1968                            .inner_with(&context.module.types)
1969                            .pointer_base_type();
1970                        let result_ty_handle = match result_ty {
1971                            Some(TypeResolution::Handle(handle)) => handle,
1972                            Some(TypeResolution::Value(_)) => {
1973                                // As long as the result of a pointer access expression is
1974                                // passed to a function or stored in a let binding, the
1975                                // type will be in the arena. If additional uses of
1976                                // pointers become valid, this assumption might no longer
1977                                // hold. Note that the LHS of a load or store doesn't
1978                                // take this path -- there is dedicated code in `put_load`
1979                                // and `put_store`.
1980                                unreachable!(
1981                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1982                                );
1983                            }
1984                            None => {
1985                                unreachable!(
1986                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1987                                )
1988                            }
1989                        };
1990                        let name_key =
1991                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
1992                        self.out.write_str(&self.names[&name_key])?;
1993                    } else {
1994                        write!(self.out, "DefaultConstructible()")?;
1995                    }
1996
1997                    if !is_scoped {
1998                        write!(self.out, ")")?;
1999                    }
2000                } else {
2001                    self.put_access_chain(expr_handle, policy, context)?;
2002                }
2003            }
2004            crate::Expression::Swizzle {
2005                size,
2006                vector,
2007                pattern,
2008            } => {
2009                self.put_wrapped_expression_for_packed_vec3_access(
2010                    vector,
2011                    context,
2012                    false,
2013                    &Self::put_expression,
2014                )?;
2015                write!(self.out, ".")?;
2016                for &sc in pattern[..size as usize].iter() {
2017                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2018                }
2019            }
2020            crate::Expression::FunctionArgument(index) => {
2021                let name_key = match context.origin {
2022                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2023                    FunctionOrigin::EntryPoint(ep_index) => {
2024                        NameKey::EntryPointArgument(ep_index, index)
2025                    }
2026                };
2027                let name = &self.names[&name_key];
2028                write!(self.out, "{name}")?;
2029            }
2030            crate::Expression::GlobalVariable(handle) => {
2031                let name = &self.names[&NameKey::GlobalVariable(handle)];
2032                write!(self.out, "{name}")?;
2033            }
2034            crate::Expression::LocalVariable(handle) => {
2035                let name_key = NameKey::local(context.origin, handle);
2036                let name = &self.names[&name_key];
2037                write!(self.out, "{name}")?;
2038            }
2039            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2040            crate::Expression::ImageSample {
2041                coordinate,
2042                image,
2043                sampler,
2044                clamp_to_edge: true,
2045                gather: None,
2046                array_index: None,
2047                offset: None,
2048                level: crate::SampleLevel::Zero,
2049                depth_ref: None,
2050            } => {
2051                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2052                self.put_expression(image, context, true)?;
2053                write!(self.out, ", ")?;
2054                self.put_expression(sampler, context, true)?;
2055                write!(self.out, ", ")?;
2056                self.put_expression(coordinate, context, true)?;
2057                write!(self.out, ")")?;
2058            }
2059            crate::Expression::ImageSample {
2060                image,
2061                sampler,
2062                gather,
2063                coordinate,
2064                array_index,
2065                offset,
2066                level,
2067                depth_ref,
2068                clamp_to_edge,
2069            } => {
2070                if clamp_to_edge {
2071                    return Err(Error::GenericValidation(
2072                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2073                    ));
2074                }
2075
2076                let main_op = match gather {
2077                    Some(_) => "gather",
2078                    None => "sample",
2079                };
2080                let comparison_op = match depth_ref {
2081                    Some(_) => "_compare",
2082                    None => "",
2083                };
2084                self.put_expression(image, context, false)?;
2085                write!(self.out, ".{main_op}{comparison_op}(")?;
2086                self.put_expression(sampler, context, true)?;
2087                write!(self.out, ", ")?;
2088                self.put_expression(coordinate, context, true)?;
2089                if let Some(expr) = array_index {
2090                    write!(self.out, ", ")?;
2091                    self.put_expression(expr, context, true)?;
2092                }
2093                if let Some(dref) = depth_ref {
2094                    write!(self.out, ", ")?;
2095                    self.put_expression(dref, context, true)?;
2096                }
2097
2098                self.put_image_sample_level(image, level, context)?;
2099
2100                if let Some(offset) = offset {
2101                    write!(self.out, ", ")?;
2102                    self.put_expression(offset, context, true)?;
2103                }
2104
2105                match gather {
2106                    None | Some(crate::SwizzleComponent::X) => {}
2107                    Some(component) => {
2108                        let is_cube_map = match *context.resolve_type(image) {
2109                            crate::TypeInner::Image {
2110                                dim: crate::ImageDimension::Cube,
2111                                ..
2112                            } => true,
2113                            _ => false,
2114                        };
2115                        // Offset always comes before the gather, except
2116                        // in cube maps where it's not applicable
2117                        if offset.is_none() && !is_cube_map {
2118                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2119                        }
2120                        let letter = back::COMPONENTS[component as usize];
2121                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2122                    }
2123                }
2124                write!(self.out, ")")?;
2125            }
2126            crate::Expression::ImageLoad {
2127                image,
2128                coordinate,
2129                array_index,
2130                sample,
2131                level,
2132            } => {
2133                let address = TexelAddress {
2134                    coordinate,
2135                    array_index,
2136                    sample,
2137                    level: level.map(LevelOfDetail::Direct),
2138                };
2139                self.put_image_load(expr_handle, image, address, context)?;
2140            }
2141            //Note: for all the queries, the signed integers are expected,
2142            // so a conversion is needed.
2143            crate::Expression::ImageQuery { image, query } => match query {
2144                crate::ImageQuery::Size { level } => {
2145                    self.put_image_size_query(
2146                        image,
2147                        level.map(LevelOfDetail::Direct),
2148                        crate::ScalarKind::Uint,
2149                        context,
2150                    )?;
2151                }
2152                crate::ImageQuery::NumLevels => {
2153                    self.put_expression(image, context, false)?;
2154                    write!(self.out, ".get_num_mip_levels()")?;
2155                }
2156                crate::ImageQuery::NumLayers => {
2157                    self.put_expression(image, context, false)?;
2158                    write!(self.out, ".get_array_size()")?;
2159                }
2160                crate::ImageQuery::NumSamples => {
2161                    self.put_expression(image, context, false)?;
2162                    write!(self.out, ".get_num_samples()")?;
2163                }
2164            },
2165            crate::Expression::Unary { op, expr } => {
2166                let op_str = match op {
2167                    crate::UnaryOperator::Negate => {
2168                        match context.resolve_type(expr).scalar_kind() {
2169                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2170                            _ => "-",
2171                        }
2172                    }
2173                    crate::UnaryOperator::LogicalNot => "!",
2174                    crate::UnaryOperator::BitwiseNot => "~",
2175                };
2176                write!(self.out, "{op_str}(")?;
2177                self.put_expression(expr, context, false)?;
2178                write!(self.out, ")")?;
2179            }
2180            crate::Expression::Binary { op, left, right } => {
2181                let kind = context
2182                    .resolve_type(left)
2183                    .scalar_kind()
2184                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2185
2186                if op == crate::BinaryOperator::Divide
2187                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2188                {
2189                    write!(self.out, "{DIV_FUNCTION}(")?;
2190                    self.put_expression(left, context, true)?;
2191                    write!(self.out, ", ")?;
2192                    self.put_expression(right, context, true)?;
2193                    write!(self.out, ")")?;
2194                } else if op == crate::BinaryOperator::Modulo
2195                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2196                {
2197                    write!(self.out, "{MOD_FUNCTION}(")?;
2198                    self.put_expression(left, context, true)?;
2199                    write!(self.out, ", ")?;
2200                    self.put_expression(right, context, true)?;
2201                    write!(self.out, ")")?;
2202                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2203                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2204                    //
2205                    // float:
2206                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2207                    write!(self.out, "{NAMESPACE}::fmod(")?;
2208                    self.put_expression(left, context, true)?;
2209                    write!(self.out, ", ")?;
2210                    self.put_expression(right, context, true)?;
2211                    write!(self.out, ")")?;
2212                } else if (op == crate::BinaryOperator::Add
2213                    || op == crate::BinaryOperator::Subtract
2214                    || op == crate::BinaryOperator::Multiply)
2215                    && kind == crate::ScalarKind::Sint
2216                {
2217                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2218                        crate::TypeInner::Scalar(scalar) => {
2219                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2220                                kind: crate::ScalarKind::Uint,
2221                                ..scalar
2222                            }))
2223                        }
2224                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2225                            size,
2226                            scalar: crate::Scalar {
2227                                kind: crate::ScalarKind::Uint,
2228                                ..scalar
2229                            },
2230                        }),
2231                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2232                    };
2233
2234                    // Avoid undefined behaviour due to overflowing signed
2235                    // integer arithmetic. Cast the operands to unsigned prior
2236                    // to performing the operation, then cast the result back
2237                    // to signed.
2238                    self.put_bitcasted_expression(
2239                        context.resolve_type(expr_handle),
2240                        context,
2241                        &|writer, context, is_scoped| {
2242                            writer.put_binop(
2243                                op,
2244                                left,
2245                                right,
2246                                context,
2247                                is_scoped,
2248                                &|writer, expr, context, _is_scoped| {
2249                                    writer.put_bitcasted_expression(
2250                                        &to_unsigned(context.resolve_type(expr))?,
2251                                        context,
2252                                        &|writer, context, is_scoped| {
2253                                            writer.put_expression(expr, context, is_scoped)
2254                                        },
2255                                    )
2256                                },
2257                            )
2258                        },
2259                    )?;
2260                } else {
2261                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2262                }
2263            }
2264            crate::Expression::Select {
2265                condition,
2266                accept,
2267                reject,
2268            } => match *context.resolve_type(condition) {
2269                crate::TypeInner::Scalar(crate::Scalar {
2270                    kind: crate::ScalarKind::Bool,
2271                    ..
2272                }) => {
2273                    if !is_scoped {
2274                        write!(self.out, "(")?;
2275                    }
2276                    self.put_expression(condition, context, false)?;
2277                    write!(self.out, " ? ")?;
2278                    self.put_expression(accept, context, false)?;
2279                    write!(self.out, " : ")?;
2280                    self.put_expression(reject, context, false)?;
2281                    if !is_scoped {
2282                        write!(self.out, ")")?;
2283                    }
2284                }
2285                crate::TypeInner::Vector {
2286                    scalar:
2287                        crate::Scalar {
2288                            kind: crate::ScalarKind::Bool,
2289                            ..
2290                        },
2291                    ..
2292                } => {
2293                    write!(self.out, "{NAMESPACE}::select(")?;
2294                    self.put_expression(reject, context, true)?;
2295                    write!(self.out, ", ")?;
2296                    self.put_expression(accept, context, true)?;
2297                    write!(self.out, ", ")?;
2298                    self.put_expression(condition, context, true)?;
2299                    write!(self.out, ")")?;
2300                }
2301                ref ty => {
2302                    return Err(Error::GenericValidation(format!(
2303                        "Expected select condition to be a non-bool type, got {ty:?}",
2304                    )))
2305                }
2306            },
2307            crate::Expression::Derivative { axis, expr, .. } => {
2308                use crate::DerivativeAxis as Axis;
2309                let op = match axis {
2310                    Axis::X => "dfdx",
2311                    Axis::Y => "dfdy",
2312                    Axis::Width => "fwidth",
2313                };
2314                write!(self.out, "{NAMESPACE}::{op}")?;
2315                self.put_call_parameters(iter::once(expr), context)?;
2316            }
2317            crate::Expression::Relational { fun, argument } => {
2318                let op = match fun {
2319                    crate::RelationalFunction::Any => "any",
2320                    crate::RelationalFunction::All => "all",
2321                    crate::RelationalFunction::IsNan => "isnan",
2322                    crate::RelationalFunction::IsInf => "isinf",
2323                };
2324                write!(self.out, "{NAMESPACE}::{op}")?;
2325                self.put_call_parameters(iter::once(argument), context)?;
2326            }
2327            crate::Expression::Math {
2328                fun,
2329                arg,
2330                arg1,
2331                arg2,
2332                arg3,
2333            } => {
2334                use crate::MathFunction as Mf;
2335
2336                let arg_type = context.resolve_type(arg);
2337                let scalar_argument = match arg_type {
2338                    &crate::TypeInner::Scalar(_) => true,
2339                    _ => false,
2340                };
2341
2342                let fun_name = match fun {
2343                    // comparison
2344                    Mf::Abs => "abs",
2345                    Mf::Min => "min",
2346                    Mf::Max => "max",
2347                    Mf::Clamp => "clamp",
2348                    Mf::Saturate => "saturate",
2349                    // trigonometry
2350                    Mf::Cos => "cos",
2351                    Mf::Cosh => "cosh",
2352                    Mf::Sin => "sin",
2353                    Mf::Sinh => "sinh",
2354                    Mf::Tan => "tan",
2355                    Mf::Tanh => "tanh",
2356                    Mf::Acos => "acos",
2357                    Mf::Asin => "asin",
2358                    Mf::Atan => "atan",
2359                    Mf::Atan2 => "atan2",
2360                    Mf::Asinh => "asinh",
2361                    Mf::Acosh => "acosh",
2362                    Mf::Atanh => "atanh",
2363                    Mf::Radians => "",
2364                    Mf::Degrees => "",
2365                    // decomposition
2366                    Mf::Ceil => "ceil",
2367                    Mf::Floor => "floor",
2368                    Mf::Round => "rint",
2369                    Mf::Fract => "fract",
2370                    Mf::Trunc => "trunc",
2371                    Mf::Modf => MODF_FUNCTION,
2372                    Mf::Frexp => FREXP_FUNCTION,
2373                    Mf::Ldexp => "ldexp",
2374                    // exponent
2375                    Mf::Exp => "exp",
2376                    Mf::Exp2 => "exp2",
2377                    Mf::Log => "log",
2378                    Mf::Log2 => "log2",
2379                    Mf::Pow => "pow",
2380                    // geometry
2381                    Mf::Dot => match *context.resolve_type(arg) {
2382                        crate::TypeInner::Vector {
2383                            scalar:
2384                                crate::Scalar {
2385                                    // Resolve float values to MSL's builtin dot function.
2386                                    kind: crate::ScalarKind::Float,
2387                                    ..
2388                                },
2389                            ..
2390                        } => "dot",
2391                        crate::TypeInner::Vector {
2392                            size,
2393                            scalar:
2394                                scalar @ crate::Scalar {
2395                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2396                                    ..
2397                                },
2398                        } => {
2399                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2400                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2401                            write!(self.out, "{fun_name}(")?;
2402                            self.put_expression(arg, context, true)?;
2403                            write!(self.out, ", ")?;
2404                            self.put_expression(arg1.unwrap(), context, true)?;
2405                            write!(self.out, ")")?;
2406                            return Ok(());
2407                        }
2408                        _ => unreachable!(
2409                            "Correct TypeInner for dot product should be already validated"
2410                        ),
2411                    },
2412                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2413                        if context.lang_version >= (2, 1) {
2414                            // Write potentially optimizable code using `packed_(u?)char4`.
2415                            // The two function arguments were already reinterpreted as packed (signed
2416                            // or unsigned) chars in `Self::put_block`.
2417                            let packed_type = match fun {
2418                                Mf::Dot4I8Packed => "packed_char4",
2419                                Mf::Dot4U8Packed => "packed_uchar4",
2420                                _ => unreachable!(),
2421                            };
2422
2423                            return self.put_dot_product(
2424                                Reinterpreted::new(packed_type, arg),
2425                                Reinterpreted::new(packed_type, arg1.unwrap()),
2426                                4,
2427                                |writer, arg, index| {
2428                                    // MSL implicitly promotes these (signed or unsigned) chars to
2429                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2430                                    write!(writer.out, "{arg}[{index}]")?;
2431                                    Ok(())
2432                                },
2433                            );
2434                        } else {
2435                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2436                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2437                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2438                            let conversion = match fun {
2439                                Mf::Dot4I8Packed => "int",
2440                                Mf::Dot4U8Packed => "",
2441                                _ => unreachable!(),
2442                            };
2443
2444                            return self.put_dot_product(
2445                                arg,
2446                                arg1.unwrap(),
2447                                4,
2448                                |writer, arg, index| {
2449                                    write!(writer.out, "({conversion}(")?;
2450                                    writer.put_expression(arg, context, true)?;
2451                                    if index == 3 {
2452                                        write!(writer.out, ") >> 24)")?;
2453                                    } else {
2454                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2455                                    }
2456                                    Ok(())
2457                                },
2458                            );
2459                        }
2460                    }
2461                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2462                    Mf::Cross => "cross",
2463                    Mf::Distance => "distance",
2464                    Mf::Length if scalar_argument => "abs",
2465                    Mf::Length => "length",
2466                    Mf::Normalize => "normalize",
2467                    Mf::FaceForward => "faceforward",
2468                    Mf::Reflect => "reflect",
2469                    Mf::Refract => "refract",
2470                    // computational
2471                    Mf::Sign => match arg_type.scalar_kind() {
2472                        Some(crate::ScalarKind::Sint) => {
2473                            return self.put_isign(arg, context);
2474                        }
2475                        _ => "sign",
2476                    },
2477                    Mf::Fma => "fma",
2478                    Mf::Mix => "mix",
2479                    Mf::Step => "step",
2480                    Mf::SmoothStep => "smoothstep",
2481                    Mf::Sqrt => "sqrt",
2482                    Mf::InverseSqrt => "rsqrt",
2483                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2484                    Mf::Transpose => "transpose",
2485                    Mf::Determinant => "determinant",
2486                    Mf::QuantizeToF16 => "",
2487                    // bits
2488                    Mf::CountTrailingZeros => "ctz",
2489                    Mf::CountLeadingZeros => "clz",
2490                    Mf::CountOneBits => "popcount",
2491                    Mf::ReverseBits => "reverse_bits",
2492                    Mf::ExtractBits => "",
2493                    Mf::InsertBits => "",
2494                    Mf::FirstTrailingBit => "",
2495                    Mf::FirstLeadingBit => "",
2496                    // data packing
2497                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2498                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2499                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2500                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2501                    Mf::Pack2x16float => "",
2502                    Mf::Pack4xI8 => "",
2503                    Mf::Pack4xU8 => "",
2504                    Mf::Pack4xI8Clamp => "",
2505                    Mf::Pack4xU8Clamp => "",
2506                    // data unpacking
2507                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2508                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2509                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2510                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2511                    Mf::Unpack2x16float => "",
2512                    Mf::Unpack4xI8 => "",
2513                    Mf::Unpack4xU8 => "",
2514                };
2515
2516                match fun {
2517                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2518                        // reverse_bits is listed as requiring MSL 2.1 but that
2519                        // is a copy/paste error. Looking at previous snapshots
2520                        // on web.archive.org it's present in MSL 1.2.
2521                        //
2522                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2523                        // also talks about MSL 1.2 adding "New integer
2524                        // functions to extract, insert, and reverse bits, as
2525                        // described in Integer Functions."
2526                        if context.lang_version < (1, 2) {
2527                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2528                        }
2529                    }
2530                    _ => {}
2531                }
2532
2533                match fun {
2534                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2535                        write!(self.out, "{ABS_FUNCTION}(")?;
2536                        self.put_expression(arg, context, true)?;
2537                        write!(self.out, ")")?;
2538                    }
2539                    Mf::Distance if scalar_argument => {
2540                        write!(self.out, "{NAMESPACE}::abs(")?;
2541                        self.put_expression(arg, context, false)?;
2542                        write!(self.out, " - ")?;
2543                        self.put_expression(arg1.unwrap(), context, false)?;
2544                        write!(self.out, ")")?;
2545                    }
2546                    Mf::FirstTrailingBit => {
2547                        let scalar = context.resolve_type(arg).scalar().unwrap();
2548                        let constant = scalar.width * 8 + 1;
2549
2550                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2551                        self.put_expression(arg, context, true)?;
2552                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2553                    }
2554                    Mf::FirstLeadingBit => {
2555                        let inner = context.resolve_type(arg);
2556                        let scalar = inner.scalar().unwrap();
2557                        let constant = scalar.width * 8 - 1;
2558
2559                        write!(
2560                            self.out,
2561                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2562                        )?;
2563
2564                        if scalar.kind == crate::ScalarKind::Sint {
2565                            write!(self.out, "{NAMESPACE}::select(")?;
2566                            self.put_expression(arg, context, true)?;
2567                            write!(self.out, ", ~")?;
2568                            self.put_expression(arg, context, true)?;
2569                            write!(self.out, ", ")?;
2570                            self.put_expression(arg, context, true)?;
2571                            write!(self.out, " < 0)")?;
2572                        } else {
2573                            self.put_expression(arg, context, true)?;
2574                        }
2575
2576                        write!(self.out, "), ")?;
2577
2578                        // or metal will complain that select is ambiguous
2579                        match *inner {
2580                            crate::TypeInner::Vector { size, scalar } => {
2581                                let size = common::vector_size_str(size);
2582                                let name = scalar.to_msl_name();
2583                                write!(self.out, "{name}{size}")?;
2584                            }
2585                            crate::TypeInner::Scalar(scalar) => {
2586                                let name = scalar.to_msl_name();
2587                                write!(self.out, "{name}")?;
2588                            }
2589                            _ => (),
2590                        }
2591
2592                        write!(self.out, "(-1), ")?;
2593                        self.put_expression(arg, context, true)?;
2594                        write!(self.out, " == 0 || ")?;
2595                        self.put_expression(arg, context, true)?;
2596                        write!(self.out, " == -1)")?;
2597                    }
2598                    Mf::Unpack2x16float => {
2599                        write!(self.out, "float2(as_type<half2>(")?;
2600                        self.put_expression(arg, context, false)?;
2601                        write!(self.out, "))")?;
2602                    }
2603                    Mf::Pack2x16float => {
2604                        write!(self.out, "as_type<uint>(half2(")?;
2605                        self.put_expression(arg, context, false)?;
2606                        write!(self.out, "))")?;
2607                    }
2608                    Mf::ExtractBits => {
2609                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2610                        // to sanitize the offset and count first. If we don't do this, Apple chips
2611                        // will return out-of-spec values if the extracted range is not within the bit width.
2612                        //
2613                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2614                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2615                        //
2616                        // w = sizeof(x) * 8
2617                        // o = min(offset, w)
2618                        // tmp = w - o
2619                        // c = min(count, tmp)
2620                        //
2621                        // bitfieldExtract(x, o, c)
2622                        //
2623                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))
2624
2625                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2626
2627                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2628                        self.put_expression(arg, context, true)?;
2629                        write!(self.out, ", {NAMESPACE}::min(")?;
2630                        self.put_expression(arg1.unwrap(), context, true)?;
2631                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2632                        self.put_expression(arg2.unwrap(), context, true)?;
2633                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2634                        self.put_expression(arg1.unwrap(), context, true)?;
2635                        write!(self.out, ", {scalar_bits}u)))")?;
2636                    }
2637                    Mf::InsertBits => {
2638                        // The behavior of InsertBits has the same issue as ExtractBits.
2639                        //
2640                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))
2641
2642                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2643
2644                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2645                        self.put_expression(arg, context, true)?;
2646                        write!(self.out, ", ")?;
2647                        self.put_expression(arg1.unwrap(), context, true)?;
2648                        write!(self.out, ", {NAMESPACE}::min(")?;
2649                        self.put_expression(arg2.unwrap(), context, true)?;
2650                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2651                        self.put_expression(arg3.unwrap(), context, true)?;
2652                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2653                        self.put_expression(arg2.unwrap(), context, true)?;
2654                        write!(self.out, ", {scalar_bits}u)))")?;
2655                    }
2656                    Mf::Radians => {
2657                        write!(self.out, "((")?;
2658                        self.put_expression(arg, context, false)?;
2659                        write!(self.out, ") * 0.017453292519943295474)")?;
2660                    }
2661                    Mf::Degrees => {
2662                        write!(self.out, "((")?;
2663                        self.put_expression(arg, context, false)?;
2664                        write!(self.out, ") * 57.295779513082322865)")?;
2665                    }
2666                    Mf::Modf | Mf::Frexp => {
2667                        write!(self.out, "{fun_name}")?;
2668                        self.put_call_parameters(iter::once(arg), context)?;
2669                    }
2670                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2671                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2672                    Mf::Pack4xI8Clamp => {
2673                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2674                    }
2675                    Mf::Pack4xU8Clamp => {
2676                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2677                    }
2678                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2679                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2680                            "u"
2681                        } else {
2682                            ""
2683                        };
2684
2685                        if context.lang_version >= (2, 1) {
2686                            // Metal uses little endian byte order, which matches what WGSL expects here.
2687                            write!(
2688                                self.out,
2689                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2690                            )?;
2691                            self.put_expression(arg, context, true)?;
2692                            write!(self.out, "))")?;
2693                        } else {
2694                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2695                            write!(self.out, "({sign_prefix}int4(")?;
2696                            self.put_expression(arg, context, true)?;
2697                            write!(self.out, ", ")?;
2698                            self.put_expression(arg, context, true)?;
2699                            write!(self.out, " >> 8, ")?;
2700                            self.put_expression(arg, context, true)?;
2701                            write!(self.out, " >> 16, ")?;
2702                            self.put_expression(arg, context, true)?;
2703                            write!(self.out, " >> 24) << 24 >> 24)")?;
2704                        }
2705                    }
2706                    Mf::QuantizeToF16 => {
2707                        match *context.resolve_type(arg) {
2708                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2709                            crate::TypeInner::Vector { size, .. } => write!(
2710                                self.out,
2711                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2712                                size = common::vector_size_str(size),
2713                            )?,
2714                            _ => unreachable!(
2715                                "Correct TypeInner for QuantizeToF16 should be already validated"
2716                            ),
2717                        };
2718
2719                        self.put_expression(arg, context, true)?;
2720                        write!(self.out, "))")?;
2721                    }
2722                    _ => {
2723                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2724                        self.put_call_parameters(
2725                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2726                            context,
2727                        )?;
2728                    }
2729                }
2730            }
2731            crate::Expression::As {
2732                expr,
2733                kind,
2734                convert,
2735            } => match *context.resolve_type(expr) {
2736                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2737                    if src.kind == crate::ScalarKind::Float
2738                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2739                        && convert.is_some()
2740                    {
2741                        // Use helper functions for float to int casts in order to avoid
2742                        // undefined behaviour when value is out of range for the target
2743                        // type.
2744                        let fun_name = match (kind, convert) {
2745                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2746                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2747                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2748                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2749                            _ => unreachable!(),
2750                        };
2751                        write!(self.out, "{fun_name}(")?;
2752                        self.put_expression(expr, context, true)?;
2753                        write!(self.out, ")")?;
2754                    } else {
2755                        let target_scalar = crate::Scalar {
2756                            kind,
2757                            width: convert.unwrap_or(src.width),
2758                        };
2759                        let op = match convert {
2760                            Some(_) => "static_cast",
2761                            None => "as_type",
2762                        };
2763                        write!(self.out, "{op}<")?;
2764                        match *context.resolve_type(expr) {
2765                            crate::TypeInner::Vector { size, .. } => {
2766                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2767                            }
2768                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2769                        };
2770                        write!(self.out, ">(")?;
2771                        self.put_expression(expr, context, true)?;
2772                        write!(self.out, ")")?;
2773                    }
2774                }
2775                crate::TypeInner::Matrix {
2776                    columns,
2777                    rows,
2778                    scalar,
2779                } => {
2780                    let target_scalar = crate::Scalar {
2781                        kind,
2782                        width: convert.unwrap_or(scalar.width),
2783                    };
2784                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2785                    write!(self.out, "(")?;
2786                    self.put_expression(expr, context, true)?;
2787                    write!(self.out, ")")?;
2788                }
2789                ref ty => {
2790                    return Err(Error::GenericValidation(format!(
2791                        "Unsupported type for As: {ty:?}"
2792                    )))
2793                }
2794            },
2795            // has to be a named expression
2796            crate::Expression::CallResult(_)
2797            | crate::Expression::AtomicResult { .. }
2798            | crate::Expression::WorkGroupUniformLoadResult { .. }
2799            | crate::Expression::SubgroupBallotResult
2800            | crate::Expression::SubgroupOperationResult { .. }
2801            | crate::Expression::RayQueryProceedResult => {
2802                unreachable!()
2803            }
2804            crate::Expression::ArrayLength(expr) => {
2805                // Find the global to which the array belongs.
2806                let global = match context.function.expressions[expr] {
2807                    crate::Expression::AccessIndex { base, .. } => {
2808                        match context.function.expressions[base] {
2809                            crate::Expression::GlobalVariable(handle) => handle,
2810                            ref ex => {
2811                                return Err(Error::GenericValidation(format!(
2812                                    "Expected global variable in AccessIndex, got {ex:?}"
2813                                )))
2814                            }
2815                        }
2816                    }
2817                    crate::Expression::GlobalVariable(handle) => handle,
2818                    ref ex => {
2819                        return Err(Error::GenericValidation(format!(
2820                            "Unexpected expression in ArrayLength, got {ex:?}"
2821                        )))
2822                    }
2823                };
2824
2825                if !is_scoped {
2826                    write!(self.out, "(")?;
2827                }
2828                write!(self.out, "1 + ")?;
2829                self.put_dynamic_array_max_index(global, context)?;
2830                if !is_scoped {
2831                    write!(self.out, ")")?;
2832                }
2833            }
2834            crate::Expression::RayQueryVertexPositions { .. } => {
2835                unimplemented!()
2836            }
2837            crate::Expression::RayQueryGetIntersection {
2838                query,
2839                committed: _,
2840            } => {
2841                if context.lang_version < (2, 4) {
2842                    return Err(Error::UnsupportedRayTracing);
2843                }
2844
2845                let ty = context.module.special_types.ray_intersection.unwrap();
2846                let type_name = &self.names[&NameKey::Type(ty)];
2847                write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2848                self.put_expression(query, context, true)?;
2849                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2850                let fields = [
2851                    "distance",
2852                    "user_instance_id", // req Metal 2.4
2853                    "instance_id",
2854                    "", // SBT offset
2855                    "geometry_id",
2856                    "primitive_id",
2857                    "triangle_barycentric_coord",
2858                    "triangle_front_facing",
2859                    "",                          // padding
2860                    "object_to_world_transform", // req Metal 2.4
2861                    "world_to_object_transform", // req Metal 2.4
2862                ];
2863                for field in fields {
2864                    write!(self.out, ", ")?;
2865                    if field.is_empty() {
2866                        write!(self.out, "{{}}")?;
2867                    } else {
2868                        self.put_expression(query, context, true)?;
2869                        write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2870                    }
2871                }
2872                write!(self.out, "}}")?;
2873            }
2874            crate::Expression::CooperativeLoad { ref data, .. } => {
2875                if context.lang_version < (2, 3) {
2876                    return Err(Error::UnsupportedCooperativeMatrix);
2877                }
2878                write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2879                write!(self.out, "&")?;
2880                self.put_access_chain(data.pointer, context.policies.index, context)?;
2881                write!(self.out, ", ")?;
2882                self.put_expression(data.stride, context, true)?;
2883                write!(self.out, ", {})", data.row_major)?;
2884            }
2885            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2886                if context.lang_version < (2, 3) {
2887                    return Err(Error::UnsupportedCooperativeMatrix);
2888                }
2889                write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2890                self.put_expression(a, context, true)?;
2891                write!(self.out, ", ")?;
2892                self.put_expression(b, context, true)?;
2893                write!(self.out, ", ")?;
2894                self.put_expression(c, context, true)?;
2895                write!(self.out, ")")?;
2896            }
2897        }
2898        Ok(())
2899    }
2900
2901    /// Emits code for a binary operation, using the provided callback to emit
2902    /// the left and right operands.
2903    fn put_binop<F>(
2904        &mut self,
2905        op: crate::BinaryOperator,
2906        left: Handle<crate::Expression>,
2907        right: Handle<crate::Expression>,
2908        context: &ExpressionContext,
2909        is_scoped: bool,
2910        put_expression: &F,
2911    ) -> BackendResult
2912    where
2913        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2914    {
2915        let op_str = back::binary_operation_str(op);
2916
2917        if !is_scoped {
2918            write!(self.out, "(")?;
2919        }
2920
2921        // Cast packed vector if necessary
2922        // Packed vector - matrix multiplications are not supported in MSL
2923        if op == crate::BinaryOperator::Multiply
2924            && matches!(
2925                context.resolve_type(right),
2926                &crate::TypeInner::Matrix { .. }
2927            )
2928        {
2929            self.put_wrapped_expression_for_packed_vec3_access(
2930                left,
2931                context,
2932                false,
2933                put_expression,
2934            )?;
2935        } else {
2936            put_expression(self, left, context, false)?;
2937        }
2938
2939        write!(self.out, " {op_str} ")?;
2940
2941        // See comment above
2942        if op == crate::BinaryOperator::Multiply
2943            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2944        {
2945            self.put_wrapped_expression_for_packed_vec3_access(
2946                right,
2947                context,
2948                false,
2949                put_expression,
2950            )?;
2951        } else {
2952            put_expression(self, right, context, false)?;
2953        }
2954
2955        if !is_scoped {
2956            write!(self.out, ")")?;
2957        }
2958
2959        Ok(())
2960    }
2961
2962    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
2963    fn put_wrapped_expression_for_packed_vec3_access<F>(
2964        &mut self,
2965        expr_handle: Handle<crate::Expression>,
2966        context: &ExpressionContext,
2967        is_scoped: bool,
2968        put_expression: &F,
2969    ) -> BackendResult
2970    where
2971        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2972    {
2973        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2974            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2975            put_expression(self, expr_handle, context, is_scoped)?;
2976            write!(self.out, ")")?;
2977        } else {
2978            put_expression(self, expr_handle, context, is_scoped)?;
2979        }
2980        Ok(())
2981    }
2982
2983    /// Emits code for an expression using the provided callback, wrapping the
2984    /// result in a bitcast to the type `cast_to`.
2985    fn put_bitcasted_expression<F>(
2986        &mut self,
2987        cast_to: &crate::TypeInner,
2988        context: &ExpressionContext,
2989        put_expression: &F,
2990    ) -> BackendResult
2991    where
2992        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2993    {
2994        write!(self.out, "as_type<")?;
2995        match *cast_to {
2996            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2997            crate::TypeInner::Vector { size, scalar } => {
2998                put_numeric_type(&mut self.out, scalar, &[size])?
2999            }
3000            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3001        };
3002        write!(self.out, ">(")?;
3003        put_expression(self, context, true)?;
3004        write!(self.out, ")")?;
3005
3006        Ok(())
3007    }
3008
3009    /// Write a `GuardedIndex` as a Metal expression.
3010    fn put_index(
3011        &mut self,
3012        index: index::GuardedIndex,
3013        context: &ExpressionContext,
3014        is_scoped: bool,
3015    ) -> BackendResult {
3016        match index {
3017            index::GuardedIndex::Expression(expr) => {
3018                self.put_expression(expr, context, is_scoped)?
3019            }
3020            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3021        }
3022        Ok(())
3023    }
3024
3025    /// Emit an index bounds check condition for `chain`, if required.
3026    ///
3027    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
3028    /// operating either on a pointer to a value, or on a value directly. If we cannot
3029    /// statically determine that all indexing operations in `chain` are within
3030    /// bounds, then write a conditional expression to check them dynamically,
3031    /// and return true. All accesses in the chain are checked by the generated
3032    /// expression.
3033    ///
3034    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
3035    ///
3036    /// The text written is of the form:
3037    ///
3038    /// ```ignore
3039    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
3040    /// ```
3041    ///
3042    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
3043    /// statements, presumably these arguments start an indented `if` statement; for
3044    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
3045    /// expression. In either case, what is written is not a complete syntactic structure
3046    /// in its own right, and the caller will have to finish it off if we return `true`.
3047    ///
3048    /// If no expression is written, return false.
3049    ///
3050    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
3051    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3052    /// [`Store`]: crate::Statement::Store
3053    /// [`Load`]: crate::Expression::Load
3054    fn put_bounds_checks(
3055        &mut self,
3056        chain: Handle<crate::Expression>,
3057        context: &ExpressionContext,
3058        level: back::Level,
3059        prefix: &'static str,
3060    ) -> Result<bool, Error> {
3061        let mut check_written = false;
3062
3063        // Iterate over the access chain, handling each required bounds check.
3064        for item in context.bounds_check_iter(chain) {
3065            let BoundsCheck {
3066                base,
3067                index,
3068                length,
3069            } = item;
3070
3071            if check_written {
3072                write!(self.out, " && ")?;
3073            } else {
3074                write!(self.out, "{level}{prefix}")?;
3075                check_written = true;
3076            }
3077
3078            // Check that the index falls within bounds. Do this with a single
3079            // comparison, by casting the index to `uint` first, so that negative
3080            // indices become large positive values.
3081            write!(self.out, "uint(")?;
3082            self.put_index(index, context, true)?;
3083            self.out.write_str(") < ")?;
3084            match length {
3085                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3086                index::IndexableLength::Dynamic => {
3087                    let global = context.function.originating_global(base).ok_or_else(|| {
3088                        Error::GenericValidation("Could not find originating global".into())
3089                    })?;
3090                    write!(self.out, "1 + ")?;
3091                    self.put_dynamic_array_max_index(global, context)?
3092                }
3093            }
3094        }
3095
3096        Ok(check_written)
3097    }
3098
3099    /// Write the access chain `chain`.
3100    ///
3101    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3102    /// operating either on a pointer to a value, or on a value directly.
3103    ///
3104    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3105    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3106    /// that must be handled in the caller.
3107    ///
3108    /// Handle the entire chain, recursing back into `put_expression` only for index
3109    /// expressions and the base expression that originates the pointer or composite value
3110    /// being accessed. This allows `put_expression` to assume that any `Access` or
3111    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3112    /// `ReadZeroSkipWrite` checks.
3113    ///
3114    /// [`Access`]: crate::Expression::Access
3115    /// [`AccessIndex`]: crate::Expression::AccessIndex
3116    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3117    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3118    fn put_access_chain(
3119        &mut self,
3120        chain: Handle<crate::Expression>,
3121        policy: index::BoundsCheckPolicy,
3122        context: &ExpressionContext,
3123    ) -> BackendResult {
3124        match context.function.expressions[chain] {
3125            crate::Expression::Access { base, index } => {
3126                let mut base_ty = context.resolve_type(base);
3127
3128                // Look through any pointers to see what we're really indexing.
3129                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3130                    base_ty = &context.module.types[base].inner;
3131                }
3132
3133                self.put_subscripted_access_chain(
3134                    base,
3135                    base_ty,
3136                    index::GuardedIndex::Expression(index),
3137                    policy,
3138                    context,
3139                )?;
3140            }
3141            crate::Expression::AccessIndex { base, index } => {
3142                let base_resolution = &context.info[base].ty;
3143                let mut base_ty = base_resolution.inner_with(&context.module.types);
3144                let mut base_ty_handle = base_resolution.handle();
3145
3146                // Look through any pointers to see what we're really indexing.
3147                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3148                    base_ty = &context.module.types[base].inner;
3149                    base_ty_handle = Some(base);
3150                }
3151
3152                // Handle structs and anything else that can use `.x` syntax here, so
3153                // `put_subscripted_access_chain` won't have to handle the absurd case of
3154                // indexing a struct with an expression.
3155                match *base_ty {
3156                    crate::TypeInner::Struct { .. } => {
3157                        let base_ty = base_ty_handle.unwrap();
3158                        self.put_access_chain(base, policy, context)?;
3159                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3160                        write!(self.out, ".{name}")?;
3161                    }
3162                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3163                        self.put_access_chain(base, policy, context)?;
3164                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3165                        // however array indexing is
3166                        if context.get_packed_vec_kind(base).is_some() {
3167                            write!(self.out, "[{index}]")?;
3168                        } else {
3169                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3170                        }
3171                    }
3172                    _ => {
3173                        self.put_subscripted_access_chain(
3174                            base,
3175                            base_ty,
3176                            index::GuardedIndex::Known(index),
3177                            policy,
3178                            context,
3179                        )?;
3180                    }
3181                }
3182            }
3183            _ => self.put_expression(chain, context, false)?,
3184        }
3185
3186        Ok(())
3187    }
3188
3189    /// Write a `[]`-style access of `base` by `index`.
3190    ///
3191    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3192    /// values within bounds.
3193    ///
3194    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3195    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3196    /// removed. Our callers often already have this handy.
3197    ///
3198    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3199    /// referencing vector components by name.
3200    ///
3201    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3202    /// [`Array`]: crate::TypeInner::Array
3203    /// [`Vector`]: crate::TypeInner::Vector
3204    /// [`Pointer`]: crate::TypeInner::Pointer
3205    fn put_subscripted_access_chain(
3206        &mut self,
3207        base: Handle<crate::Expression>,
3208        base_ty: &crate::TypeInner,
3209        index: index::GuardedIndex,
3210        policy: index::BoundsCheckPolicy,
3211        context: &ExpressionContext,
3212    ) -> BackendResult {
3213        let accessing_wrapped_array = match *base_ty {
3214            crate::TypeInner::Array {
3215                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3216                ..
3217            } => true,
3218            _ => false,
3219        };
3220        let accessing_wrapped_binding_array =
3221            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3222
3223        self.put_access_chain(base, policy, context)?;
3224        if accessing_wrapped_array {
3225            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3226        }
3227        write!(self.out, "[")?;
3228
3229        // Decide whether this index needs to be clamped to fall within range.
3230        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3231            context.access_needs_check(base, index)
3232        } else {
3233            None
3234        };
3235        if let Some(limit) = restriction_needed {
3236            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3237            self.put_index(index, context, true)?;
3238            write!(self.out, "), ")?;
3239            match limit {
3240                index::IndexableLength::Known(limit) => {
3241                    write!(self.out, "{}u", limit - 1)?;
3242                }
3243                index::IndexableLength::Dynamic => {
3244                    let global = context.function.originating_global(base).ok_or_else(|| {
3245                        Error::GenericValidation("Could not find originating global".into())
3246                    })?;
3247                    self.put_dynamic_array_max_index(global, context)?;
3248                }
3249            }
3250            write!(self.out, ")")?;
3251        } else {
3252            self.put_index(index, context, true)?;
3253        }
3254
3255        write!(self.out, "]")?;
3256
3257        if accessing_wrapped_binding_array {
3258            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3259        }
3260
3261        Ok(())
3262    }
3263
3264    fn put_load(
3265        &mut self,
3266        pointer: Handle<crate::Expression>,
3267        context: &ExpressionContext,
3268        is_scoped: bool,
3269    ) -> BackendResult {
3270        // Since access chains never cross between address spaces, we can just
3271        // check the index bounds check policy once at the top.
3272        let policy = context.choose_bounds_check_policy(pointer);
3273        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3274            && self.put_bounds_checks(
3275                pointer,
3276                context,
3277                back::Level(0),
3278                if is_scoped { "" } else { "(" },
3279            )?
3280        {
3281            write!(self.out, " ? ")?;
3282            self.put_unchecked_load(pointer, policy, context)?;
3283            write!(self.out, " : DefaultConstructible()")?;
3284
3285            if !is_scoped {
3286                write!(self.out, ")")?;
3287            }
3288        } else {
3289            self.put_unchecked_load(pointer, policy, context)?;
3290        }
3291
3292        Ok(())
3293    }
3294
3295    fn put_unchecked_load(
3296        &mut self,
3297        pointer: Handle<crate::Expression>,
3298        policy: index::BoundsCheckPolicy,
3299        context: &ExpressionContext,
3300    ) -> BackendResult {
3301        let is_atomic_pointer = context
3302            .resolve_type(pointer)
3303            .is_atomic_pointer(&context.module.types);
3304
3305        if is_atomic_pointer {
3306            write!(
3307                self.out,
3308                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3309            )?;
3310            self.put_access_chain(pointer, policy, context)?;
3311            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3312        } else {
3313            // We don't do any dereferencing with `*` here as pointer arguments to functions
3314            // are done by `&` references and not `*` pointers. These do not need to be
3315            // dereferenced.
3316            self.put_access_chain(pointer, policy, context)?;
3317        }
3318
3319        Ok(())
3320    }
3321
3322    fn put_return_value(
3323        &mut self,
3324        level: back::Level,
3325        expr_handle: Handle<crate::Expression>,
3326        result_struct: Option<&str>,
3327        context: &ExpressionContext,
3328    ) -> BackendResult {
3329        match result_struct {
3330            Some(struct_name) => {
3331                let mut has_point_size = false;
3332                let result_ty = context.function.result.as_ref().unwrap().ty;
3333                match context.module.types[result_ty].inner {
3334                    crate::TypeInner::Struct { ref members, .. } => {
3335                        let tmp = "_tmp";
3336                        write!(self.out, "{level}const auto {tmp} = ")?;
3337                        self.put_expression(expr_handle, context, true)?;
3338                        writeln!(self.out, ";")?;
3339                        write!(self.out, "{level}return {struct_name} {{")?;
3340
3341                        let mut is_first = true;
3342
3343                        for (index, member) in members.iter().enumerate() {
3344                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3345                                member.binding
3346                            {
3347                                has_point_size = true;
3348                                if !context.pipeline_options.allow_and_force_point_size {
3349                                    continue;
3350                                }
3351                            }
3352
3353                            let comma = if is_first { "" } else { "," };
3354                            is_first = false;
3355                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3356                            // HACK: we are forcefully deduplicating the expression here
3357                            // to convert from a wrapped struct to a raw array, e.g.
3358                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3359                            if let crate::TypeInner::Array {
3360                                size: crate::ArraySize::Constant(size),
3361                                ..
3362                            } = context.module.types[member.ty].inner
3363                            {
3364                                write!(self.out, "{comma} {{")?;
3365                                for j in 0..size.get() {
3366                                    if j != 0 {
3367                                        write!(self.out, ",")?;
3368                                    }
3369                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3370                                }
3371                                write!(self.out, "}}")?;
3372                            } else {
3373                                write!(self.out, "{comma} {tmp}.{name}")?;
3374                            }
3375                        }
3376                    }
3377                    _ => {
3378                        write!(self.out, "{level}return {struct_name} {{ ")?;
3379                        self.put_expression(expr_handle, context, true)?;
3380                    }
3381                }
3382
3383                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3384                    let stage = context.module.entry_points[ep_index as usize].stage;
3385                    if context.pipeline_options.allow_and_force_point_size
3386                        && stage == crate::ShaderStage::Vertex
3387                        && !has_point_size
3388                    {
3389                        // point size was injected and comes last
3390                        write!(self.out, ", 1.0")?;
3391                    }
3392                }
3393                write!(self.out, " }}")?;
3394            }
3395            None => {
3396                write!(self.out, "{level}return ")?;
3397                self.put_expression(expr_handle, context, true)?;
3398            }
3399        }
3400        writeln!(self.out, ";")?;
3401        Ok(())
3402    }
3403
3404    /// Helper method used to find which expressions of a given function require baking
3405    ///
3406    /// # Notes
3407    /// This function overwrites the contents of `self.need_bake_expressions`
3408    fn update_expressions_to_bake(
3409        &mut self,
3410        func: &crate::Function,
3411        info: &valid::FunctionInfo,
3412        context: &ExpressionContext,
3413    ) {
3414        use crate::Expression;
3415        self.need_bake_expressions.clear();
3416
3417        for (expr_handle, expr) in func.expressions.iter() {
3418            // Expressions whose reference count is above the
3419            // threshold should always be stored in temporaries.
3420            let expr_info = &info[expr_handle];
3421            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3422            if min_ref_count <= expr_info.ref_count {
3423                self.need_bake_expressions.insert(expr_handle);
3424            } else {
3425                match expr_info.ty {
3426                    // force ray desc to be baked: it's used multiple times internally
3427                    TypeResolution::Handle(h)
3428                        if Some(h) == context.module.special_types.ray_desc =>
3429                    {
3430                        self.need_bake_expressions.insert(expr_handle);
3431                    }
3432                    _ => {}
3433                }
3434            }
3435
3436            if let Expression::Math {
3437                fun,
3438                arg,
3439                arg1,
3440                arg2,
3441                ..
3442            } = *expr
3443            {
3444                match fun {
3445                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3446                    // works on floating-point vectors, so we emit inline code for
3447                    // integer vector `dot` calls. But that code uses each argument `N`
3448                    // times, once for each component (see `put_dot_product`), so to
3449                    // avoid duplicated evaluation, we must bake integer operands.
3450                    // This applies both when using the polyfill (because of the duplicate
3451                    // evaluation issue) and when we don't use the polyfill (because we
3452                    // need them to be emitted before casting to packed chars -- see the
3453                    // comment at the call to `put_casting_to_packed_chars`).
3454                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3455                        self.need_bake_expressions.insert(arg);
3456                        self.need_bake_expressions.insert(arg1.unwrap());
3457                    }
3458                    crate::MathFunction::FirstLeadingBit => {
3459                        self.need_bake_expressions.insert(arg);
3460                    }
3461                    crate::MathFunction::Pack4xI8
3462                    | crate::MathFunction::Pack4xU8
3463                    | crate::MathFunction::Pack4xI8Clamp
3464                    | crate::MathFunction::Pack4xU8Clamp
3465                    | crate::MathFunction::Unpack4xI8
3466                    | crate::MathFunction::Unpack4xU8 => {
3467                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3468                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3469                        if context.lang_version < (2, 1) {
3470                            self.need_bake_expressions.insert(arg);
3471                        }
3472                    }
3473                    crate::MathFunction::ExtractBits => {
3474                        // Only argument 1 is re-used.
3475                        self.need_bake_expressions.insert(arg1.unwrap());
3476                    }
3477                    crate::MathFunction::InsertBits => {
3478                        // Only argument 2 is re-used.
3479                        self.need_bake_expressions.insert(arg2.unwrap());
3480                    }
3481                    crate::MathFunction::Sign => {
3482                        // WGSL's `sign` function works also on signed ints, but Metal's only
3483                        // works on floating points, so we emit inline code for integer `sign`
3484                        // calls. But that code uses each argument 2 times (see `put_isign`),
3485                        // so to avoid duplicated evaluation, we must bake the argument.
3486                        let inner = context.resolve_type(expr_handle);
3487                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3488                            self.need_bake_expressions.insert(arg);
3489                        }
3490                    }
3491                    _ => {}
3492                }
3493            }
3494        }
3495    }
3496
3497    fn start_baking_expression(
3498        &mut self,
3499        handle: Handle<crate::Expression>,
3500        context: &ExpressionContext,
3501        name: &str,
3502    ) -> BackendResult {
3503        match context.info[handle].ty {
3504            TypeResolution::Handle(ty_handle) => {
3505                let ty_name = TypeContext {
3506                    handle: ty_handle,
3507                    gctx: context.module.to_ctx(),
3508                    names: &self.names,
3509                    access: crate::StorageAccess::empty(),
3510                    first_time: false,
3511                };
3512                write!(self.out, "{ty_name}")?;
3513            }
3514            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3515                put_numeric_type(&mut self.out, scalar, &[])?;
3516            }
3517            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3518                put_numeric_type(&mut self.out, scalar, &[size])?;
3519            }
3520            TypeResolution::Value(crate::TypeInner::Matrix {
3521                columns,
3522                rows,
3523                scalar,
3524            }) => {
3525                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3526            }
3527            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3528                columns,
3529                rows,
3530                scalar,
3531                role: _,
3532            }) => {
3533                write!(
3534                    self.out,
3535                    "{}::simdgroup_{}{}x{}",
3536                    NAMESPACE,
3537                    scalar.to_msl_name(),
3538                    columns as u32,
3539                    rows as u32,
3540                )?;
3541            }
3542            TypeResolution::Value(ref other) => {
3543                log::warn!("Type {other:?} isn't a known local");
3544                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3545            }
3546        }
3547
3548        //TODO: figure out the naming scheme that wouldn't collide with user names.
3549        write!(self.out, " {name} = ")?;
3550
3551        Ok(())
3552    }
3553
3554    /// Cache a clamped level of detail value, if necessary.
3555    ///
3556    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3557    /// properly clamped level of detail value both in the access itself, and
3558    /// for fetching the size of the requested MIP level, needed to clamp the
3559    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3560    /// it in a temporary variable, as part of the [`Emit`] statement covering
3561    /// the [`ImageLoad`] expression.
3562    ///
3563    /// [`ImageLoad`]: crate::Expression::ImageLoad
3564    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3565    /// [`Emit`]: crate::Statement::Emit
3566    fn put_cache_restricted_level(
3567        &mut self,
3568        load: Handle<crate::Expression>,
3569        image: Handle<crate::Expression>,
3570        mip_level: Option<Handle<crate::Expression>>,
3571        indent: back::Level,
3572        context: &StatementContext,
3573    ) -> BackendResult {
3574        // Does this image access actually require (or even permit) a
3575        // level-of-detail, and does the policy require us to restrict it?
3576        let level_of_detail = match mip_level {
3577            Some(level) => level,
3578            None => return Ok(()),
3579        };
3580
3581        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3582            || !context.expression.image_needs_lod(image)
3583        {
3584            return Ok(());
3585        }
3586
3587        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3588        self.put_restricted_scalar_image_index(
3589            image,
3590            level_of_detail,
3591            "get_num_mip_levels",
3592            &context.expression,
3593        )?;
3594        writeln!(self.out, ";")?;
3595
3596        Ok(())
3597    }
3598
3599    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3600    ///
3601    /// Caches the results in temporary variables (whose names are derived from
3602    /// the original variable names). This caching avoids the need to redo the
3603    /// casting for each vector component when emitting the dot product.
3604    fn put_casting_to_packed_chars(
3605        &mut self,
3606        fun: crate::MathFunction,
3607        arg0: Handle<crate::Expression>,
3608        arg1: Handle<crate::Expression>,
3609        indent: back::Level,
3610        context: &StatementContext<'_>,
3611    ) -> Result<(), Error> {
3612        let packed_type = match fun {
3613            crate::MathFunction::Dot4I8Packed => "packed_char4",
3614            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3615            _ => unreachable!(),
3616        };
3617
3618        for arg in [arg0, arg1] {
3619            write!(
3620                self.out,
3621                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3622                Reinterpreted::new(packed_type, arg)
3623            )?;
3624            self.put_expression(arg, &context.expression, true)?;
3625            writeln!(self.out, ");")?;
3626        }
3627
3628        Ok(())
3629    }
3630
3631    fn put_block(
3632        &mut self,
3633        level: back::Level,
3634        statements: &[crate::Statement],
3635        context: &StatementContext,
3636    ) -> BackendResult {
3637        // Add to the set in order to track the stack size.
3638        #[cfg(test)]
3639        self.put_block_stack_pointers
3640            .insert(ptr::from_ref(&level).cast());
3641
3642        for statement in statements {
3643            log::trace!("statement[{}] {:?}", level.0, statement);
3644            match *statement {
3645                crate::Statement::Emit(ref range) => {
3646                    for handle in range.clone() {
3647                        use crate::MathFunction as Mf;
3648
3649                        match context.expression.function.expressions[handle] {
3650                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3651                            // may need to cache a clamped version of their level-of-detail argument.
3652                            crate::Expression::ImageLoad {
3653                                image,
3654                                level: mip_level,
3655                                ..
3656                            } => {
3657                                self.put_cache_restricted_level(
3658                                    handle, image, mip_level, level, context,
3659                                )?;
3660                            }
3661
3662                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3663                            // 2.1+ then we introduce two intermediate variables that recast the two
3664                            // arguments as packed (signed or unsigned) chars. The actual dot product
3665                            // is implemented in `Self::put_expression`, and it uses both of these
3666                            // intermediate variables multiple times. There's no danger that the
3667                            // original arguments get modified between the definition of these
3668                            // intermediate variables and the implementation of the actual dot
3669                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3670                            crate::Expression::Math {
3671                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3672                                arg,
3673                                arg1,
3674                                ..
3675                            } if context.expression.lang_version >= (2, 1) => {
3676                                self.put_casting_to_packed_chars(
3677                                    fun,
3678                                    arg,
3679                                    arg1.unwrap(),
3680                                    level,
3681                                    context,
3682                                )?;
3683                            }
3684
3685                            _ => (),
3686                        }
3687
3688                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3689                        let expr_name = if ptr_class.is_some() {
3690                            None // don't bake pointer expressions (just yet)
3691                        } else if let Some(name) =
3692                            context.expression.function.named_expressions.get(&handle)
3693                        {
3694                            // The `crate::Function::named_expressions` table holds
3695                            // expressions that should be saved in temporaries once they
3696                            // are `Emit`ted. We only add them to `self.named_expressions`
3697                            // when we reach the `Emit` that covers them, so that we don't
3698                            // try to use their names before we've actually initialized
3699                            // the temporary that holds them.
3700                            //
3701                            // Don't assume the names in `named_expressions` are unique,
3702                            // or even valid. Use the `Namer`.
3703                            Some(self.namer.call(name))
3704                        } else {
3705                            // If this expression is an index that we're going to first compare
3706                            // against a limit, and then actually use as an index, then we may
3707                            // want to cache it in a temporary, to avoid evaluating it twice.
3708                            let bake = if context.expression.guarded_indices.contains(handle) {
3709                                true
3710                            } else {
3711                                self.need_bake_expressions.contains(&handle)
3712                            };
3713
3714                            if bake {
3715                                Some(Baked(handle).to_string())
3716                            } else {
3717                                None
3718                            }
3719                        };
3720
3721                        if let Some(name) = expr_name {
3722                            write!(self.out, "{level}")?;
3723                            self.start_baking_expression(handle, &context.expression, &name)?;
3724                            self.put_expression(handle, &context.expression, true)?;
3725                            self.named_expressions.insert(handle, name);
3726                            writeln!(self.out, ";")?;
3727                        }
3728                    }
3729                }
3730                crate::Statement::Block(ref block) => {
3731                    if !block.is_empty() {
3732                        writeln!(self.out, "{level}{{")?;
3733                        self.put_block(level.next(), block, context)?;
3734                        writeln!(self.out, "{level}}}")?;
3735                    }
3736                }
3737                crate::Statement::If {
3738                    condition,
3739                    ref accept,
3740                    ref reject,
3741                } => {
3742                    write!(self.out, "{level}if (")?;
3743                    self.put_expression(condition, &context.expression, true)?;
3744                    writeln!(self.out, ") {{")?;
3745                    self.put_block(level.next(), accept, context)?;
3746                    if !reject.is_empty() {
3747                        writeln!(self.out, "{level}}} else {{")?;
3748                        self.put_block(level.next(), reject, context)?;
3749                    }
3750                    writeln!(self.out, "{level}}}")?;
3751                }
3752                crate::Statement::Switch {
3753                    selector,
3754                    ref cases,
3755                } => {
3756                    write!(self.out, "{level}switch(")?;
3757                    self.put_expression(selector, &context.expression, true)?;
3758                    writeln!(self.out, ") {{")?;
3759                    let lcase = level.next();
3760                    for case in cases.iter() {
3761                        match case.value {
3762                            crate::SwitchValue::I32(value) => {
3763                                write!(self.out, "{lcase}case {value}:")?;
3764                            }
3765                            crate::SwitchValue::U32(value) => {
3766                                write!(self.out, "{lcase}case {value}u:")?;
3767                            }
3768                            crate::SwitchValue::Default => {
3769                                write!(self.out, "{lcase}default:")?;
3770                            }
3771                        }
3772
3773                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3774                        if write_block_braces {
3775                            writeln!(self.out, " {{")?;
3776                        } else {
3777                            writeln!(self.out)?;
3778                        }
3779
3780                        self.put_block(lcase.next(), &case.body, context)?;
3781                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3782                        {
3783                            writeln!(self.out, "{}break;", lcase.next())?;
3784                        }
3785
3786                        if write_block_braces {
3787                            writeln!(self.out, "{lcase}}}")?;
3788                        }
3789                    }
3790                    writeln!(self.out, "{level}}}")?;
3791                }
3792                crate::Statement::Loop {
3793                    ref body,
3794                    ref continuing,
3795                    break_if,
3796                } => {
3797                    let force_loop_bound_statements =
3798                        self.gen_force_bounded_loop_statements(level, context);
3799                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3800                        .then(|| self.namer.call("loop_init"));
3801
3802                    if let Some((ref decl, _)) = force_loop_bound_statements {
3803                        writeln!(self.out, "{decl}")?;
3804                    }
3805                    if let Some(ref gate_name) = gate_name {
3806                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3807                    }
3808
3809                    writeln!(self.out, "{level}while(true) {{",)?;
3810                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3811                        writeln!(self.out, "{break_and_inc}")?;
3812                    }
3813                    if let Some(ref gate_name) = gate_name {
3814                        let lif = level.next();
3815                        let lcontinuing = lif.next();
3816                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3817                        self.put_block(lcontinuing, continuing, context)?;
3818                        if let Some(condition) = break_if {
3819                            write!(self.out, "{lcontinuing}if (")?;
3820                            self.put_expression(condition, &context.expression, true)?;
3821                            writeln!(self.out, ") {{")?;
3822                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3823                            writeln!(self.out, "{lcontinuing}}}")?;
3824                        }
3825                        writeln!(self.out, "{lif}}}")?;
3826                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3827                    }
3828                    self.put_block(level.next(), body, context)?;
3829
3830                    writeln!(self.out, "{level}}}")?;
3831                }
3832                crate::Statement::Break => {
3833                    writeln!(self.out, "{level}break;")?;
3834                }
3835                crate::Statement::Continue => {
3836                    writeln!(self.out, "{level}continue;")?;
3837                }
3838                crate::Statement::Return {
3839                    value: Some(expr_handle),
3840                } => {
3841                    self.put_return_value(
3842                        level,
3843                        expr_handle,
3844                        context.result_struct,
3845                        &context.expression,
3846                    )?;
3847                }
3848                crate::Statement::Return { value: None } => {
3849                    writeln!(self.out, "{level}return;")?;
3850                }
3851                crate::Statement::Kill => {
3852                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3853                }
3854                crate::Statement::ControlBarrier(flags)
3855                | crate::Statement::MemoryBarrier(flags) => {
3856                    self.write_barrier(flags, level)?;
3857                }
3858                crate::Statement::Store { pointer, value } => {
3859                    self.put_store(pointer, value, level, context)?
3860                }
3861                crate::Statement::ImageStore {
3862                    image,
3863                    coordinate,
3864                    array_index,
3865                    value,
3866                } => {
3867                    let address = TexelAddress {
3868                        coordinate,
3869                        array_index,
3870                        sample: None,
3871                        level: None,
3872                    };
3873                    self.put_image_store(level, image, &address, value, context)?
3874                }
3875                crate::Statement::Call {
3876                    function,
3877                    ref arguments,
3878                    result,
3879                } => {
3880                    write!(self.out, "{level}")?;
3881                    if let Some(expr) = result {
3882                        let name = Baked(expr).to_string();
3883                        self.start_baking_expression(expr, &context.expression, &name)?;
3884                        self.named_expressions.insert(expr, name);
3885                    }
3886                    let fun_name = &self.names[&NameKey::Function(function)];
3887                    write!(self.out, "{fun_name}(")?;
3888                    // first, write down the actual arguments
3889                    for (i, &handle) in arguments.iter().enumerate() {
3890                        if i != 0 {
3891                            write!(self.out, ", ")?;
3892                        }
3893                        self.put_expression(handle, &context.expression, true)?;
3894                    }
3895                    // follow-up with any global resources used
3896                    let mut separate = !arguments.is_empty();
3897                    let fun_info = &context.expression.mod_info[function];
3898                    let mut needs_buffer_sizes = false;
3899                    for (handle, var) in context.expression.module.global_variables.iter() {
3900                        if fun_info[handle].is_empty() {
3901                            continue;
3902                        }
3903                        if var.space.needs_pass_through() {
3904                            let name = &self.names[&NameKey::GlobalVariable(handle)];
3905                            if separate {
3906                                write!(self.out, ", ")?;
3907                            } else {
3908                                separate = true;
3909                            }
3910                            write!(self.out, "{name}")?;
3911                        }
3912                        needs_buffer_sizes |=
3913                            needs_array_length(var.ty, &context.expression.module.types);
3914                    }
3915                    if needs_buffer_sizes {
3916                        if separate {
3917                            write!(self.out, ", ")?;
3918                        }
3919                        write!(self.out, "_buffer_sizes")?;
3920                    }
3921
3922                    // done
3923                    writeln!(self.out, ");")?;
3924                }
3925                crate::Statement::Atomic {
3926                    pointer,
3927                    ref fun,
3928                    value,
3929                    result,
3930                } => {
3931                    let context = &context.expression;
3932
3933                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
3934                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
3935                    // `Some`, we are not operating on a 64-bit value, and that if we are
3936                    // operating on a 64-bit value, `result` is `None`.
3937                    write!(self.out, "{level}")?;
3938                    let fun_key = if let Some(result) = result {
3939                        let res_name = Baked(result).to_string();
3940                        self.start_baking_expression(result, context, &res_name)?;
3941                        self.named_expressions.insert(result, res_name);
3942                        fun.to_msl()
3943                    } else if context.resolve_type(value).scalar_width() == Some(8) {
3944                        fun.to_msl_64_bit()?
3945                    } else {
3946                        fun.to_msl()
3947                    };
3948
3949                    // If the pointer we're passing to the atomic operation needs to be conditional
3950                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
3951                    // the pointer operand should be unchecked.
3952                    let policy = context.choose_bounds_check_policy(pointer);
3953                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3954                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3955
3956                    // If requested and successfully put bounds checks, continue the ternary expression.
3957                    if checked {
3958                        write!(self.out, " ? ")?;
3959                    }
3960
3961                    // Put the atomic function invocation.
3962                    match *fun {
3963                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3964                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3965                            self.put_access_chain(pointer, policy, context)?;
3966                            write!(self.out, ", ")?;
3967                            self.put_expression(cmp, context, true)?;
3968                            write!(self.out, ", ")?;
3969                            self.put_expression(value, context, true)?;
3970                            write!(self.out, ")")?;
3971                        }
3972                        _ => {
3973                            write!(
3974                                self.out,
3975                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3976                            )?;
3977                            self.put_access_chain(pointer, policy, context)?;
3978                            write!(self.out, ", ")?;
3979                            self.put_expression(value, context, true)?;
3980                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3981                        }
3982                    }
3983
3984                    // Finish the ternary expression.
3985                    if checked {
3986                        write!(self.out, " : DefaultConstructible()")?;
3987                    }
3988
3989                    // Done
3990                    writeln!(self.out, ";")?;
3991                }
3992                crate::Statement::ImageAtomic {
3993                    image,
3994                    coordinate,
3995                    array_index,
3996                    fun,
3997                    value,
3998                } => {
3999                    let address = TexelAddress {
4000                        coordinate,
4001                        array_index,
4002                        sample: None,
4003                        level: None,
4004                    };
4005                    self.put_image_atomic(level, image, &address, fun, value, context)?
4006                }
4007                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4008                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4009
4010                    write!(self.out, "{level}")?;
4011                    let name = self.namer.call("");
4012                    self.start_baking_expression(result, &context.expression, &name)?;
4013                    self.put_load(pointer, &context.expression, true)?;
4014                    self.named_expressions.insert(result, name);
4015
4016                    writeln!(self.out, ";")?;
4017                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4018                }
4019                crate::Statement::RayQuery { query, ref fun } => {
4020                    if context.expression.lang_version < (2, 4) {
4021                        return Err(Error::UnsupportedRayTracing);
4022                    }
4023
4024                    match *fun {
4025                        crate::RayQueryFunction::Initialize {
4026                            acceleration_structure,
4027                            descriptor,
4028                        } => {
4029                            //TODO: how to deal with winding?
4030                            write!(self.out, "{level}")?;
4031                            self.put_expression(query, &context.expression, true)?;
4032                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
4033                            {
4034                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
4035                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
4036                                write!(self.out, "{level}")?;
4037                                self.put_expression(query, &context.expression, true)?;
4038                                write!(
4039                                    self.out,
4040                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
4041                                )?;
4042                                self.put_expression(descriptor, &context.expression, true)?;
4043                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
4044                                self.put_expression(descriptor, &context.expression, true)?;
4045                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
4046                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
4047                            }
4048                            {
4049                                let f_opaque = back::RayFlag::OPAQUE.bits();
4050                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
4051                                write!(self.out, "{level}")?;
4052                                self.put_expression(query, &context.expression, true)?;
4053                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
4054                                self.put_expression(descriptor, &context.expression, true)?;
4055                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
4056                                self.put_expression(descriptor, &context.expression, true)?;
4057                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
4058                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
4059                            }
4060                            {
4061                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
4062                                write!(self.out, "{level}")?;
4063                                self.put_expression(query, &context.expression, true)?;
4064                                write!(
4065                                    self.out,
4066                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
4067                                )?;
4068                                self.put_expression(descriptor, &context.expression, true)?;
4069                                writeln!(self.out, ".flags & {flag}) != 0);")?;
4070                            }
4071
4072                            write!(self.out, "{level}")?;
4073                            self.put_expression(query, &context.expression, true)?;
4074                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
4075                            self.put_expression(query, &context.expression, true)?;
4076                            write!(
4077                                self.out,
4078                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4079                            )?;
4080                            self.put_expression(descriptor, &context.expression, true)?;
4081                            write!(self.out, ".origin, ")?;
4082                            self.put_expression(descriptor, &context.expression, true)?;
4083                            write!(self.out, ".dir, ")?;
4084                            self.put_expression(descriptor, &context.expression, true)?;
4085                            write!(self.out, ".tmin, ")?;
4086                            self.put_expression(descriptor, &context.expression, true)?;
4087                            write!(self.out, ".tmax), ")?;
4088                            self.put_expression(acceleration_structure, &context.expression, true)?;
4089                            write!(self.out, ", ")?;
4090                            self.put_expression(descriptor, &context.expression, true)?;
4091                            write!(self.out, ".cull_mask);")?;
4092
4093                            write!(self.out, "{level}")?;
4094                            self.put_expression(query, &context.expression, true)?;
4095                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4096                        }
4097                        crate::RayQueryFunction::Proceed { result } => {
4098                            write!(self.out, "{level}")?;
4099                            let name = Baked(result).to_string();
4100                            self.start_baking_expression(result, &context.expression, &name)?;
4101                            self.named_expressions.insert(result, name);
4102                            self.put_expression(query, &context.expression, true)?;
4103                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4104                            if RAY_QUERY_MODERN_SUPPORT {
4105                                write!(self.out, "{level}")?;
4106                                self.put_expression(query, &context.expression, true)?;
4107                                writeln!(self.out, ".?.next();")?;
4108                            }
4109                        }
4110                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4111                            if RAY_QUERY_MODERN_SUPPORT {
4112                                write!(self.out, "{level}")?;
4113                                self.put_expression(query, &context.expression, true)?;
4114                                write!(self.out, ".?.commit_bounding_box_intersection(")?;
4115                                self.put_expression(hit_t, &context.expression, true)?;
4116                                writeln!(self.out, ");")?;
4117                            } else {
4118                                log::warn!("Ray Query GenerateIntersection is not yet supported");
4119                            }
4120                        }
4121                        crate::RayQueryFunction::ConfirmIntersection => {
4122                            if RAY_QUERY_MODERN_SUPPORT {
4123                                write!(self.out, "{level}")?;
4124                                self.put_expression(query, &context.expression, true)?;
4125                                writeln!(self.out, ".?.commit_triangle_intersection();")?;
4126                            } else {
4127                                log::warn!("Ray Query ConfirmIntersection is not yet supported");
4128                            }
4129                        }
4130                        crate::RayQueryFunction::Terminate => {
4131                            if RAY_QUERY_MODERN_SUPPORT {
4132                                write!(self.out, "{level}")?;
4133                                self.put_expression(query, &context.expression, true)?;
4134                                writeln!(self.out, ".?.abort();")?;
4135                            }
4136                            write!(self.out, "{level}")?;
4137                            self.put_expression(query, &context.expression, true)?;
4138                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4139                        }
4140                    }
4141                }
4142                crate::Statement::SubgroupBallot { result, predicate } => {
4143                    write!(self.out, "{level}")?;
4144                    let name = self.namer.call("");
4145                    self.start_baking_expression(result, &context.expression, &name)?;
4146                    self.named_expressions.insert(result, name);
4147                    write!(
4148                        self.out,
4149                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4150                    )?;
4151                    if let Some(predicate) = predicate {
4152                        self.put_expression(predicate, &context.expression, true)?;
4153                    } else {
4154                        write!(self.out, "true")?;
4155                    }
4156                    writeln!(self.out, "), 0, 0, 0);")?;
4157                }
4158                crate::Statement::SubgroupCollectiveOperation {
4159                    op,
4160                    collective_op,
4161                    argument,
4162                    result,
4163                } => {
4164                    write!(self.out, "{level}")?;
4165                    let name = self.namer.call("");
4166                    self.start_baking_expression(result, &context.expression, &name)?;
4167                    self.named_expressions.insert(result, name);
4168                    match (collective_op, op) {
4169                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4170                            write!(self.out, "{NAMESPACE}::simd_all(")?
4171                        }
4172                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4173                            write!(self.out, "{NAMESPACE}::simd_any(")?
4174                        }
4175                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4176                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4177                        }
4178                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4179                            write!(self.out, "{NAMESPACE}::simd_product(")?
4180                        }
4181                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4182                            write!(self.out, "{NAMESPACE}::simd_max(")?
4183                        }
4184                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4185                            write!(self.out, "{NAMESPACE}::simd_min(")?
4186                        }
4187                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4188                            write!(self.out, "{NAMESPACE}::simd_and(")?
4189                        }
4190                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4191                            write!(self.out, "{NAMESPACE}::simd_or(")?
4192                        }
4193                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4194                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4195                        }
4196                        (
4197                            crate::CollectiveOperation::ExclusiveScan,
4198                            crate::SubgroupOperation::Add,
4199                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4200                        (
4201                            crate::CollectiveOperation::ExclusiveScan,
4202                            crate::SubgroupOperation::Mul,
4203                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4204                        (
4205                            crate::CollectiveOperation::InclusiveScan,
4206                            crate::SubgroupOperation::Add,
4207                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4208                        (
4209                            crate::CollectiveOperation::InclusiveScan,
4210                            crate::SubgroupOperation::Mul,
4211                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4212                        _ => unimplemented!(),
4213                    }
4214                    self.put_expression(argument, &context.expression, true)?;
4215                    writeln!(self.out, ");")?;
4216                }
4217                crate::Statement::SubgroupGather {
4218                    mode,
4219                    argument,
4220                    result,
4221                } => {
4222                    write!(self.out, "{level}")?;
4223                    let name = self.namer.call("");
4224                    self.start_baking_expression(result, &context.expression, &name)?;
4225                    self.named_expressions.insert(result, name);
4226                    match mode {
4227                        crate::GatherMode::BroadcastFirst => {
4228                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4229                        }
4230                        crate::GatherMode::Broadcast(_) => {
4231                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4232                        }
4233                        crate::GatherMode::Shuffle(_) => {
4234                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4235                        }
4236                        crate::GatherMode::ShuffleDown(_) => {
4237                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4238                        }
4239                        crate::GatherMode::ShuffleUp(_) => {
4240                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4241                        }
4242                        crate::GatherMode::ShuffleXor(_) => {
4243                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4244                        }
4245                        crate::GatherMode::QuadBroadcast(_) => {
4246                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4247                        }
4248                        crate::GatherMode::QuadSwap(_) => {
4249                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4250                        }
4251                    }
4252                    self.put_expression(argument, &context.expression, true)?;
4253                    match mode {
4254                        crate::GatherMode::BroadcastFirst => {}
4255                        crate::GatherMode::Broadcast(index)
4256                        | crate::GatherMode::Shuffle(index)
4257                        | crate::GatherMode::ShuffleDown(index)
4258                        | crate::GatherMode::ShuffleUp(index)
4259                        | crate::GatherMode::ShuffleXor(index)
4260                        | crate::GatherMode::QuadBroadcast(index) => {
4261                            write!(self.out, ", ")?;
4262                            self.put_expression(index, &context.expression, true)?;
4263                        }
4264                        crate::GatherMode::QuadSwap(direction) => {
4265                            write!(self.out, ", ")?;
4266                            match direction {
4267                                crate::Direction::X => {
4268                                    write!(self.out, "1u")?;
4269                                }
4270                                crate::Direction::Y => {
4271                                    write!(self.out, "2u")?;
4272                                }
4273                                crate::Direction::Diagonal => {
4274                                    write!(self.out, "3u")?;
4275                                }
4276                            }
4277                        }
4278                    }
4279                    writeln!(self.out, ");")?;
4280                }
4281                crate::Statement::CooperativeStore { target, ref data } => {
4282                    write!(self.out, "{level}simdgroup_store(")?;
4283                    self.put_expression(target, &context.expression, true)?;
4284                    write!(self.out, ", &")?;
4285                    self.put_access_chain(
4286                        data.pointer,
4287                        context.expression.policies.index,
4288                        &context.expression,
4289                    )?;
4290                    write!(self.out, ", ")?;
4291                    self.put_expression(data.stride, &context.expression, true)?;
4292                    if data.row_major {
4293                        let matrix_origin = "0";
4294                        let transpose = true;
4295                        write!(self.out, ", {matrix_origin}, {transpose}")?;
4296                    }
4297                    writeln!(self.out, ");")?;
4298                }
4299                crate::Statement::RayPipelineFunction(_) => unreachable!(),
4300            }
4301        }
4302
4303        // un-emit expressions
4304        //TODO: take care of loop/continuing?
4305        for statement in statements {
4306            if let crate::Statement::Emit(ref range) = *statement {
4307                for handle in range.clone() {
4308                    self.named_expressions.shift_remove(&handle);
4309                }
4310            }
4311        }
4312        Ok(())
4313    }
4314
4315    fn put_store(
4316        &mut self,
4317        pointer: Handle<crate::Expression>,
4318        value: Handle<crate::Expression>,
4319        level: back::Level,
4320        context: &StatementContext,
4321    ) -> BackendResult {
4322        let policy = context.expression.choose_bounds_check_policy(pointer);
4323        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4324            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4325        {
4326            writeln!(self.out, ") {{")?;
4327            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4328            writeln!(self.out, "{level}}}")?;
4329        } else {
4330            self.put_unchecked_store(pointer, value, policy, level, context)?;
4331        }
4332
4333        Ok(())
4334    }
4335
4336    fn put_unchecked_store(
4337        &mut self,
4338        pointer: Handle<crate::Expression>,
4339        value: Handle<crate::Expression>,
4340        policy: index::BoundsCheckPolicy,
4341        level: back::Level,
4342        context: &StatementContext,
4343    ) -> BackendResult {
4344        let is_atomic_pointer = context
4345            .expression
4346            .resolve_type(pointer)
4347            .is_atomic_pointer(&context.expression.module.types);
4348
4349        if is_atomic_pointer {
4350            write!(
4351                self.out,
4352                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4353            )?;
4354            self.put_access_chain(pointer, policy, &context.expression)?;
4355            write!(self.out, ", ")?;
4356            self.put_expression(value, &context.expression, true)?;
4357            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4358        } else {
4359            write!(self.out, "{level}")?;
4360            self.put_access_chain(pointer, policy, &context.expression)?;
4361            write!(self.out, " = ")?;
4362            self.put_expression(value, &context.expression, true)?;
4363            writeln!(self.out, ";")?;
4364        }
4365
4366        Ok(())
4367    }
4368
4369    pub fn write(
4370        &mut self,
4371        module: &crate::Module,
4372        info: &valid::ModuleInfo,
4373        options: &Options,
4374        pipeline_options: &PipelineOptions,
4375    ) -> Result<TranslationInfo, Error> {
4376        self.names.clear();
4377        self.namer.reset(
4378            module,
4379            &super::keywords::RESERVED_SET,
4380            proc::KeywordSet::empty(),
4381            proc::CaseInsensitiveKeywordSet::empty(),
4382            &[CLAMPED_LOD_LOAD_PREFIX],
4383            &mut self.names,
4384        );
4385        self.wrapped_functions.clear();
4386        self.struct_member_pads.clear();
4387
4388        writeln!(
4389            self.out,
4390            "// language: metal{}.{}",
4391            options.lang_version.0, options.lang_version.1
4392        )?;
4393        writeln!(self.out, "#include <metal_stdlib>")?;
4394        writeln!(self.out, "#include <simd/simd.h>")?;
4395        writeln!(self.out)?;
4396        // Work around Metal bug where `uint` is not available by default
4397        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4398
4399        let mut uses_ray_query = false;
4400        for (_, ty) in module.types.iter() {
4401            match ty.inner {
4402                crate::TypeInner::AccelerationStructure { .. } => {
4403                    if options.lang_version < (2, 4) {
4404                        return Err(Error::UnsupportedRayTracing);
4405                    }
4406                }
4407                crate::TypeInner::RayQuery { .. } => {
4408                    if options.lang_version < (2, 4) {
4409                        return Err(Error::UnsupportedRayTracing);
4410                    }
4411                    uses_ray_query = true;
4412                }
4413                _ => (),
4414            }
4415        }
4416
4417        if module.special_types.ray_desc.is_some()
4418            || module.special_types.ray_intersection.is_some()
4419        {
4420            if options.lang_version < (2, 4) {
4421                return Err(Error::UnsupportedRayTracing);
4422            }
4423        }
4424
4425        if uses_ray_query {
4426            self.put_ray_query_type()?;
4427        }
4428
4429        if options
4430            .bounds_check_policies
4431            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4432        {
4433            self.put_default_constructible()?;
4434        }
4435        writeln!(self.out)?;
4436
4437        {
4438            // Make a `Vec` of all the `GlobalVariable`s that contain
4439            // runtime-sized arrays.
4440            let globals: Vec<Handle<crate::GlobalVariable>> = module
4441                .global_variables
4442                .iter()
4443                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4444                .map(|(handle, _)| handle)
4445                .collect();
4446
4447            let mut buffer_indices = vec![];
4448            for vbm in &pipeline_options.vertex_buffer_mappings {
4449                buffer_indices.push(vbm.id);
4450            }
4451
4452            if !globals.is_empty() || !buffer_indices.is_empty() {
4453                writeln!(self.out, "struct _mslBufferSizes {{")?;
4454
4455                for global in globals {
4456                    writeln!(
4457                        self.out,
4458                        "{}uint {};",
4459                        back::INDENT,
4460                        ArraySizeMember(global)
4461                    )?;
4462                }
4463
4464                for idx in buffer_indices {
4465                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4466                }
4467
4468                writeln!(self.out, "}};")?;
4469                writeln!(self.out)?;
4470            }
4471        };
4472
4473        self.write_type_defs(module)?;
4474        self.write_global_constants(module, info)?;
4475        self.write_functions(module, info, options, pipeline_options)
4476    }
4477
4478    /// Write the definition for the `DefaultConstructible` class.
4479    ///
4480    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4481    /// produce 'zero' values for any type, including structs, arrays, and so
4482    /// on. We could do this by emitting default constructor applications, but
4483    /// that would entail printing the name of the type, which is more trouble
4484    /// than you'd think. Instead, we just construct this magic C++14 class that
4485    /// can be converted to any type that can be default constructed, using
4486    /// template parameter inference to detect which type is needed, so we don't
4487    /// have to figure out the name.
4488    ///
4489    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4490    fn put_default_constructible(&mut self) -> BackendResult {
4491        let tab = back::INDENT;
4492        writeln!(self.out, "struct DefaultConstructible {{")?;
4493        writeln!(self.out, "{tab}template<typename T>")?;
4494        writeln!(self.out, "{tab}operator T() && {{")?;
4495        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4496        writeln!(self.out, "{tab}}}")?;
4497        writeln!(self.out, "}};")?;
4498        Ok(())
4499    }
4500
4501    fn put_ray_query_type(&mut self) -> BackendResult {
4502        let tab = back::INDENT;
4503        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4504        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4505        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4506        writeln!(
4507            self.out,
4508            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4509        )?;
4510        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4511        writeln!(self.out, "}};")?;
4512        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4513        let v_triangle = back::RayIntersectionType::Triangle as u32;
4514        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4515        writeln!(
4516            self.out,
4517            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4518        )?;
4519        writeln!(
4520            self.out,
4521            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4522        )?;
4523        writeln!(self.out, "}}")?;
4524        Ok(())
4525    }
4526
    /// Write MSL definitions for the types in `module`, followed by helper
    /// functions for the module's predeclared special types.
    ///
    /// For each type yielded by `module.types.iter()`:
    ///
    /// - The first `BindingArray` type encountered causes the
    ///   `ARGUMENT_BUFFER_WRAPPER_STRUCT` template to be written, and the
    ///   first external image type causes `EXTERNAL_TEXTURE_WRAPPER_STRUCT`
    ///   to be written. Each wrapper is written at most once.
    /// - Types for which `needs_alias()` is `false` get no definition.
    /// - Arrays of known length become a struct wrapping a C array; arrays of
    ///   dynamic length become a `typedef` of a one-element array.
    /// - Structs are written member by member, with `char _padN[..]` members
    ///   inserted wherever a member's offset (or the struct's span) lies
    ///   beyond the end of the preceding member. Each member that is preceded
    ///   by such padding is recorded in `self.struct_member_pads`.
    /// - Every other aliased type becomes a plain `typedef`.
    ///
    /// Afterwards, one wrapper function is written for every entry of
    /// `module.special_types.predeclared_types`.
    ///
    /// # Errors
    ///
    /// Propagates formatting errors and any error from resolving an array
    /// size.
    ///
    /// # Panics
    ///
    /// Panics if the module contains an external image type while
    /// `module.special_types.external_texture_params` is `None`.
    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
        // The wrapper structs do not depend on the particular type that
        // triggers them, so each is emitted only on first encounter.
        let mut generated_argument_buffer_wrapper = false;
        let mut generated_external_texture_wrapper = false;
        for (handle, ty) in module.types.iter() {
            match ty.inner {
                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
                    writeln!(self.out, "template <typename T>")?;
                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_argument_buffer_wrapper = true;
                }
                crate::TypeInner::Image {
                    class: crate::ImageClass::External,
                    ..
                } if !generated_external_texture_wrapper => {
                    // The wrapper bundles three texture planes with a
                    // parameters member whose type is the module's
                    // `external_texture_params` special type.
                    let params_ty_name = &self.names
                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
                        back::INDENT
                    )?;
                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_external_texture_wrapper = true;
                }
                _ => {}
            }

            // Only types that need an alias get a named definition.
            if !ty.needs_alias() {
                continue;
            }
            let name = &self.names[&NameKey::Type(handle)];
            match ty.inner {
                // Naga IR can pass around arrays by value, but Metal, following
                // C++, performs an array-to-pointer conversion (C++ [conv.array])
                // on expressions of array type, so assigning the array by value
                // isn't possible. However, Metal *does* assign structs by
                // value. So in our Metal output, we wrap all array types in
                // synthetic struct types:
                //
                //     struct type1 {
                //         float inner[10];
                //     };
                //
                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
                // any expression that actually wants access to the array.
                crate::TypeInner::Array {
                    base,
                    size,
                    stride: _,
                } => {
                    let base_name = TypeContext {
                        handle: base,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: false,
                    };

                    match size.resolve(module.to_ctx())? {
                        proc::IndexableLength::Known(size) => {
                            writeln!(self.out, "struct {name} {{")?;
                            writeln!(
                                self.out,
                                "{}{} {}[{}];",
                                back::INDENT,
                                base_name,
                                WRAPPED_ARRAY_FIELD,
                                size
                            )?;
                            writeln!(self.out, "}};")?;
                        }
                        // Dynamically sized arrays are not wrapped in a
                        // struct; they are declared as a one-element array
                        // typedef.
                        proc::IndexableLength::Dynamic => {
                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members, span, ..
                } => {
                    writeln!(self.out, "struct {name} {{")?;
                    // Offset just past the end of the previously written
                    // member, used to detect gaps that need explicit padding.
                    let mut last_offset = 0;
                    for (index, member) in members.iter().enumerate() {
                        if member.offset > last_offset {
                            // Remember that this member is preceded by a
                            // padding member, then emit that padding.
                            self.struct_member_pads.insert((handle, index as u32));
                            let pad = member.offset - last_offset;
                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
                        }
                        let ty_inner = &module.types[member.ty].inner;
                        last_offset = member.offset + ty_inner.size(module.to_ctx());

                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];

                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
                        match should_pack_struct_member(members, span, index, module) {
                            Some(scalar) => {
                                writeln!(
                                    self.out,
                                    "{}{}::packed_{}3 {};",
                                    back::INDENT,
                                    NAMESPACE,
                                    scalar.to_msl_name(),
                                    member_name
                                )?;
                            }
                            None => {
                                let base_name = TypeContext {
                                    handle: member.ty,
                                    gctx: module.to_ctx(),
                                    names: &self.names,
                                    access: crate::StorageAccess::empty(),
                                    first_time: false,
                                };
                                writeln!(
                                    self.out,
                                    "{}{} {};",
                                    back::INDENT,
                                    base_name,
                                    member_name
                                )?;

                                // for 3-component vectors, add one component:
                                // an unpacked 3-vector member is counted as
                                // occupying one extra scalar's worth of bytes.
                                if let crate::TypeInner::Vector {
                                    size: crate::VectorSize::Tri,
                                    scalar,
                                } = *ty_inner
                                {
                                    last_offset += scalar.width as u32;
                                }
                            }
                        }
                    }
                    // Pad the tail of the struct out to its full span. The
                    // padding member is numbered with the member count, one
                    // past the largest index a per-member pad can use.
                    if last_offset < span {
                        let pad = span - last_offset;
                        writeln!(
                            self.out,
                            "{}char _pad{}[{}];",
                            back::INDENT,
                            members.len(),
                            pad
                        )?;
                    }
                    writeln!(self.out, "}};")?;
                }
                _ => {
                    let ty_name = TypeContext {
                        handle,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: true,
                    };
                    writeln!(self.out, "typedef {ty_name} {name};")?;
                }
            }
        }

        // Write functions to create special types.
        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, scalar }
                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                    // The `*_owner` locals are declared ahead of use so that a
                    // formatted `String` outlives the `&str` borrowed from it.
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{NAMESPACE}::{}{}",
                            if scalar.width == 8 { "double" } else { "float" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if scalar.width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    // `modf` returns its second value in the argument's own
                    // type; `frexp` returns it as an `int` scalar or vector.
                    let other_type_name_owner;
                    let (defined_func_name, called_func_name, other_type_name) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (MODF_FUNCTION, "modf", arg_type_name)
                        } else {
                            let other_type_name = if let Some(size) = size {
                                other_type_name_owner = format!("int{}", size as u8);
                                &other_type_name_owner
                            } else {
                                "int"
                            };
                            (FREXP_FUNCTION, "frexp", other_type_name)
                        };

                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    // The wrapper calls the MSL builtin with an out-argument
                    // and packs both results into the predeclared struct.
                    writeln!(self.out)?;
                    writeln!(
                        self.out,
                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
    return {struct_name}{{ fract, other }};
}}"
                    )?;
                }
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
                    let arg_type_name = scalar.to_msl_name();
                    let called_func_name = "atomic_compare_exchange_weak_explicit";
                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;

                    // One overload is written per address space listed here;
                    // each returns the (possibly updated) `cmp` value together
                    // with the success flag.
                    for address_space_name in ["device", "threadgroup"] {
                        writeln!(
                            self.out,
                            "\
template <typename A>
{struct_name} {defined_func_name}(
    {address_space_name} A *atomic_ptr,
    {arg_type_name} cmp,
    {arg_type_name} v
) {{
    bool swapped = {NAMESPACE}::{called_func_name}(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return {struct_name}{{cmp, swapped}};
}}"
                        )?;
                    }
                }
            }
        }

        Ok(())
    }
4774
4775    /// Writes all named constants
4776    fn write_global_constants(
4777        &mut self,
4778        module: &crate::Module,
4779        mod_info: &valid::ModuleInfo,
4780    ) -> BackendResult {
4781        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4782
4783        for (handle, constant) in constants {
4784            let ty_name = TypeContext {
4785                handle: constant.ty,
4786                gctx: module.to_ctx(),
4787                names: &self.names,
4788                access: crate::StorageAccess::empty(),
4789                first_time: false,
4790            };
4791            let name = &self.names[&NameKey::Constant(handle)];
4792            write!(self.out, "constant {ty_name} {name} = ")?;
4793            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4794            writeln!(self.out, ";")?;
4795        }
4796
4797        Ok(())
4798    }
4799
4800    fn put_inline_sampler_properties(
4801        &mut self,
4802        level: back::Level,
4803        sampler: &sm::InlineSampler,
4804    ) -> BackendResult {
4805        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4806            writeln!(
4807                self.out,
4808                "{}{}::{}_address::{},",
4809                level,
4810                NAMESPACE,
4811                letter,
4812                address.as_str(),
4813            )?;
4814        }
4815        writeln!(
4816            self.out,
4817            "{}{}::mag_filter::{},",
4818            level,
4819            NAMESPACE,
4820            sampler.mag_filter.as_str(),
4821        )?;
4822        writeln!(
4823            self.out,
4824            "{}{}::min_filter::{},",
4825            level,
4826            NAMESPACE,
4827            sampler.min_filter.as_str(),
4828        )?;
4829        if let Some(filter) = sampler.mip_filter {
4830            writeln!(
4831                self.out,
4832                "{}{}::mip_filter::{},",
4833                level,
4834                NAMESPACE,
4835                filter.as_str(),
4836            )?;
4837        }
4838        // avoid setting it on platforms that don't support it
4839        if sampler.border_color != sm::BorderColor::TransparentBlack {
4840            writeln!(
4841                self.out,
4842                "{}{}::border_color::{},",
4843                level,
4844                NAMESPACE,
4845                sampler.border_color.as_str(),
4846            )?;
4847        }
4848        //TODO: I'm not able to feed this in a way that MSL likes:
4849        //>error: use of undeclared identifier 'lod_clamp'
4850        //>error: no member named 'max_anisotropy' in namespace 'metal'
4851        if false {
4852            if let Some(ref lod) = sampler.lod_clamp {
4853                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4854            }
4855            if let Some(aniso) = sampler.max_anisotropy {
4856                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4857            }
4858        }
4859        if sampler.compare_func != sm::CompareFunc::Never {
4860            writeln!(
4861                self.out,
4862                "{}{}::compare_func::{},",
4863                level,
4864                NAMESPACE,
4865                sampler.compare_func.as_str(),
4866            )?;
4867        }
4868        writeln!(
4869            self.out,
4870            "{}{}::coord::{}",
4871            level,
4872            NAMESPACE,
4873            sampler.coord.as_str()
4874        )?;
4875        Ok(())
4876    }
4877
4878    fn write_unpacking_function(
4879        &mut self,
4880        format: back::msl::VertexFormat,
4881    ) -> Result<(String, u32, u32), Error> {
4882        use back::msl::VertexFormat::*;
4883        match format {
4884            Uint8 => {
4885                let name = self.namer.call("unpackUint8");
4886                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4887                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4888                writeln!(self.out, "}}")?;
4889                Ok((name, 1, 1))
4890            }
4891            Uint8x2 => {
4892                let name = self.namer.call("unpackUint8x2");
4893                writeln!(
4894                    self.out,
4895                    "metal::uint2 {name}(metal::uchar b0, \
4896                                         metal::uchar b1) {{"
4897                )?;
4898                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4899                writeln!(self.out, "}}")?;
4900                Ok((name, 2, 2))
4901            }
4902            Uint8x4 => {
4903                let name = self.namer.call("unpackUint8x4");
4904                writeln!(
4905                    self.out,
4906                    "metal::uint4 {name}(metal::uchar b0, \
4907                                         metal::uchar b1, \
4908                                         metal::uchar b2, \
4909                                         metal::uchar b3) {{"
4910                )?;
4911                writeln!(
4912                    self.out,
4913                    "{}return metal::uint4(b0, b1, b2, b3);",
4914                    back::INDENT
4915                )?;
4916                writeln!(self.out, "}}")?;
4917                Ok((name, 4, 4))
4918            }
4919            Sint8 => {
4920                let name = self.namer.call("unpackSint8");
4921                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4922                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4923                writeln!(self.out, "}}")?;
4924                Ok((name, 1, 1))
4925            }
4926            Sint8x2 => {
4927                let name = self.namer.call("unpackSint8x2");
4928                writeln!(
4929                    self.out,
4930                    "metal::int2 {name}(metal::uchar b0, \
4931                                        metal::uchar b1) {{"
4932                )?;
4933                writeln!(
4934                    self.out,
4935                    "{}return metal::int2(as_type<char>(b0), \
4936                                          as_type<char>(b1));",
4937                    back::INDENT
4938                )?;
4939                writeln!(self.out, "}}")?;
4940                Ok((name, 2, 2))
4941            }
4942            Sint8x4 => {
4943                let name = self.namer.call("unpackSint8x4");
4944                writeln!(
4945                    self.out,
4946                    "metal::int4 {name}(metal::uchar b0, \
4947                                        metal::uchar b1, \
4948                                        metal::uchar b2, \
4949                                        metal::uchar b3) {{"
4950                )?;
4951                writeln!(
4952                    self.out,
4953                    "{}return metal::int4(as_type<char>(b0), \
4954                                          as_type<char>(b1), \
4955                                          as_type<char>(b2), \
4956                                          as_type<char>(b3));",
4957                    back::INDENT
4958                )?;
4959                writeln!(self.out, "}}")?;
4960                Ok((name, 4, 4))
4961            }
4962            Unorm8 => {
4963                let name = self.namer.call("unpackUnorm8");
4964                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4965                writeln!(
4966                    self.out,
4967                    "{}return float(float(b0) / 255.0f);",
4968                    back::INDENT
4969                )?;
4970                writeln!(self.out, "}}")?;
4971                Ok((name, 1, 1))
4972            }
4973            Unorm8x2 => {
4974                let name = self.namer.call("unpackUnorm8x2");
4975                writeln!(
4976                    self.out,
4977                    "metal::float2 {name}(metal::uchar b0, \
4978                                          metal::uchar b1) {{"
4979                )?;
4980                writeln!(
4981                    self.out,
4982                    "{}return metal::float2(float(b0) / 255.0f, \
4983                                            float(b1) / 255.0f);",
4984                    back::INDENT
4985                )?;
4986                writeln!(self.out, "}}")?;
4987                Ok((name, 2, 2))
4988            }
4989            Unorm8x4 => {
4990                let name = self.namer.call("unpackUnorm8x4");
4991                writeln!(
4992                    self.out,
4993                    "metal::float4 {name}(metal::uchar b0, \
4994                                          metal::uchar b1, \
4995                                          metal::uchar b2, \
4996                                          metal::uchar b3) {{"
4997                )?;
4998                writeln!(
4999                    self.out,
5000                    "{}return metal::float4(float(b0) / 255.0f, \
5001                                            float(b1) / 255.0f, \
5002                                            float(b2) / 255.0f, \
5003                                            float(b3) / 255.0f);",
5004                    back::INDENT
5005                )?;
5006                writeln!(self.out, "}}")?;
5007                Ok((name, 4, 4))
5008            }
5009            Snorm8 => {
5010                let name = self.namer.call("unpackSnorm8");
5011                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5012                writeln!(
5013                    self.out,
5014                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
5015                    back::INDENT
5016                )?;
5017                writeln!(self.out, "}}")?;
5018                Ok((name, 1, 1))
5019            }
5020            Snorm8x2 => {
5021                let name = self.namer.call("unpackSnorm8x2");
5022                writeln!(
5023                    self.out,
5024                    "metal::float2 {name}(metal::uchar b0, \
5025                                          metal::uchar b1) {{"
5026                )?;
5027                writeln!(
5028                    self.out,
5029                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5030                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5031                    back::INDENT
5032                )?;
5033                writeln!(self.out, "}}")?;
5034                Ok((name, 2, 2))
5035            }
5036            Snorm8x4 => {
5037                let name = self.namer.call("unpackSnorm8x4");
5038                writeln!(
5039                    self.out,
5040                    "metal::float4 {name}(metal::uchar b0, \
5041                                          metal::uchar b1, \
5042                                          metal::uchar b2, \
5043                                          metal::uchar b3) {{"
5044                )?;
5045                writeln!(
5046                    self.out,
5047                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5048                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5049                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5050                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5051                    back::INDENT
5052                )?;
5053                writeln!(self.out, "}}")?;
5054                Ok((name, 4, 4))
5055            }
5056            Uint16 => {
5057                let name = self.namer.call("unpackUint16");
5058                writeln!(
5059                    self.out,
5060                    "metal::uint {name}(metal::uint b0, \
5061                                        metal::uint b1) {{"
5062                )?;
5063                writeln!(
5064                    self.out,
5065                    "{}return metal::uint(b1 << 8 | b0);",
5066                    back::INDENT
5067                )?;
5068                writeln!(self.out, "}}")?;
5069                Ok((name, 2, 1))
5070            }
5071            Uint16x2 => {
5072                let name = self.namer.call("unpackUint16x2");
5073                writeln!(
5074                    self.out,
5075                    "metal::uint2 {name}(metal::uint b0, \
5076                                         metal::uint b1, \
5077                                         metal::uint b2, \
5078                                         metal::uint b3) {{"
5079                )?;
5080                writeln!(
5081                    self.out,
5082                    "{}return metal::uint2(b1 << 8 | b0, \
5083                                           b3 << 8 | b2);",
5084                    back::INDENT
5085                )?;
5086                writeln!(self.out, "}}")?;
5087                Ok((name, 4, 2))
5088            }
5089            Uint16x4 => {
5090                let name = self.namer.call("unpackUint16x4");
5091                writeln!(
5092                    self.out,
5093                    "metal::uint4 {name}(metal::uint b0, \
5094                                         metal::uint b1, \
5095                                         metal::uint b2, \
5096                                         metal::uint b3, \
5097                                         metal::uint b4, \
5098                                         metal::uint b5, \
5099                                         metal::uint b6, \
5100                                         metal::uint b7) {{"
5101                )?;
5102                writeln!(
5103                    self.out,
5104                    "{}return metal::uint4(b1 << 8 | b0, \
5105                                           b3 << 8 | b2, \
5106                                           b5 << 8 | b4, \
5107                                           b7 << 8 | b6);",
5108                    back::INDENT
5109                )?;
5110                writeln!(self.out, "}}")?;
5111                Ok((name, 8, 4))
5112            }
5113            Sint16 => {
5114                let name = self.namer.call("unpackSint16");
5115                writeln!(
5116                    self.out,
5117                    "int {name}(metal::ushort b0, \
5118                                metal::ushort b1) {{"
5119                )?;
5120                writeln!(
5121                    self.out,
5122                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5123                    back::INDENT
5124                )?;
5125                writeln!(self.out, "}}")?;
5126                Ok((name, 2, 1))
5127            }
5128            Sint16x2 => {
5129                let name = self.namer.call("unpackSint16x2");
5130                writeln!(
5131                    self.out,
5132                    "metal::int2 {name}(metal::ushort b0, \
5133                                        metal::ushort b1, \
5134                                        metal::ushort b2, \
5135                                        metal::ushort b3) {{"
5136                )?;
5137                writeln!(
5138                    self.out,
5139                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5140                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5141                    back::INDENT
5142                )?;
5143                writeln!(self.out, "}}")?;
5144                Ok((name, 4, 2))
5145            }
5146            Sint16x4 => {
5147                let name = self.namer.call("unpackSint16x4");
5148                writeln!(
5149                    self.out,
5150                    "metal::int4 {name}(metal::ushort b0, \
5151                                        metal::ushort b1, \
5152                                        metal::ushort b2, \
5153                                        metal::ushort b3, \
5154                                        metal::ushort b4, \
5155                                        metal::ushort b5, \
5156                                        metal::ushort b6, \
5157                                        metal::ushort b7) {{"
5158                )?;
5159                writeln!(
5160                    self.out,
5161                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5162                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5163                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5164                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5165                    back::INDENT
5166                )?;
5167                writeln!(self.out, "}}")?;
5168                Ok((name, 8, 4))
5169            }
5170            Unorm16 => {
5171                let name = self.namer.call("unpackUnorm16");
5172                writeln!(
5173                    self.out,
5174                    "float {name}(metal::ushort b0, \
5175                                  metal::ushort b1) {{"
5176                )?;
5177                writeln!(
5178                    self.out,
5179                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5180                    back::INDENT
5181                )?;
5182                writeln!(self.out, "}}")?;
5183                Ok((name, 2, 1))
5184            }
5185            Unorm16x2 => {
5186                let name = self.namer.call("unpackUnorm16x2");
5187                writeln!(
5188                    self.out,
5189                    "metal::float2 {name}(metal::ushort b0, \
5190                                          metal::ushort b1, \
5191                                          metal::ushort b2, \
5192                                          metal::ushort b3) {{"
5193                )?;
5194                writeln!(
5195                    self.out,
5196                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5197                                            float(b3 << 8 | b2) / 65535.0f);",
5198                    back::INDENT
5199                )?;
5200                writeln!(self.out, "}}")?;
5201                Ok((name, 4, 2))
5202            }
5203            Unorm16x4 => {
5204                let name = self.namer.call("unpackUnorm16x4");
5205                writeln!(
5206                    self.out,
5207                    "metal::float4 {name}(metal::ushort b0, \
5208                                          metal::ushort b1, \
5209                                          metal::ushort b2, \
5210                                          metal::ushort b3, \
5211                                          metal::ushort b4, \
5212                                          metal::ushort b5, \
5213                                          metal::ushort b6, \
5214                                          metal::ushort b7) {{"
5215                )?;
5216                writeln!(
5217                    self.out,
5218                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5219                                            float(b3 << 8 | b2) / 65535.0f, \
5220                                            float(b5 << 8 | b4) / 65535.0f, \
5221                                            float(b7 << 8 | b6) / 65535.0f);",
5222                    back::INDENT
5223                )?;
5224                writeln!(self.out, "}}")?;
5225                Ok((name, 8, 4))
5226            }
5227            Snorm16 => {
5228                let name = self.namer.call("unpackSnorm16");
5229                writeln!(
5230                    self.out,
5231                    "float {name}(metal::ushort b0, \
5232                                  metal::ushort b1) {{"
5233                )?;
5234                writeln!(
5235                    self.out,
5236                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5237                    back::INDENT
5238                )?;
5239                writeln!(self.out, "}}")?;
5240                Ok((name, 2, 1))
5241            }
5242            Snorm16x2 => {
5243                let name = self.namer.call("unpackSnorm16x2");
5244                writeln!(
5245                    self.out,
5246                    "metal::float2 {name}(uint b0, \
5247                                          uint b1, \
5248                                          uint b2, \
5249                                          uint b3) {{"
5250                )?;
5251                writeln!(
5252                    self.out,
5253                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5254                    back::INDENT
5255                )?;
5256                writeln!(self.out, "}}")?;
5257                Ok((name, 4, 2))
5258            }
5259            Snorm16x4 => {
5260                let name = self.namer.call("unpackSnorm16x4");
5261                writeln!(
5262                    self.out,
5263                    "metal::float4 {name}(uint b0, \
5264                                          uint b1, \
5265                                          uint b2, \
5266                                          uint b3, \
5267                                          uint b4, \
5268                                          uint b5, \
5269                                          uint b6, \
5270                                          uint b7) {{"
5271                )?;
5272                writeln!(
5273                    self.out,
5274                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5275                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5276                    back::INDENT
5277                )?;
5278                writeln!(self.out, "}}")?;
5279                Ok((name, 8, 4))
5280            }
5281            Float16 => {
5282                let name = self.namer.call("unpackFloat16");
5283                writeln!(
5284                    self.out,
5285                    "float {name}(metal::ushort b0, \
5286                                  metal::ushort b1) {{"
5287                )?;
5288                writeln!(
5289                    self.out,
5290                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5291                    back::INDENT
5292                )?;
5293                writeln!(self.out, "}}")?;
5294                Ok((name, 2, 1))
5295            }
5296            Float16x2 => {
5297                let name = self.namer.call("unpackFloat16x2");
5298                writeln!(
5299                    self.out,
5300                    "metal::float2 {name}(metal::ushort b0, \
5301                                          metal::ushort b1, \
5302                                          metal::ushort b2, \
5303                                          metal::ushort b3) {{"
5304                )?;
5305                writeln!(
5306                    self.out,
5307                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5308                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5309                    back::INDENT
5310                )?;
5311                writeln!(self.out, "}}")?;
5312                Ok((name, 4, 2))
5313            }
5314            Float16x4 => {
5315                let name = self.namer.call("unpackFloat16x4");
5316                writeln!(
5317                    self.out,
5318                    "metal::float4 {name}(metal::ushort b0, \
5319                                        metal::ushort b1, \
5320                                        metal::ushort b2, \
5321                                        metal::ushort b3, \
5322                                        metal::ushort b4, \
5323                                        metal::ushort b5, \
5324                                        metal::ushort b6, \
5325                                        metal::ushort b7) {{"
5326                )?;
5327                writeln!(
5328                    self.out,
5329                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5330                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5331                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5332                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5333                    back::INDENT
5334                )?;
5335                writeln!(self.out, "}}")?;
5336                Ok((name, 8, 4))
5337            }
5338            Float32 => {
5339                let name = self.namer.call("unpackFloat32");
5340                writeln!(
5341                    self.out,
5342                    "float {name}(uint b0, \
5343                                  uint b1, \
5344                                  uint b2, \
5345                                  uint b3) {{"
5346                )?;
5347                writeln!(
5348                    self.out,
5349                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5350                    back::INDENT
5351                )?;
5352                writeln!(self.out, "}}")?;
5353                Ok((name, 4, 1))
5354            }
5355            Float32x2 => {
5356                let name = self.namer.call("unpackFloat32x2");
5357                writeln!(
5358                    self.out,
5359                    "metal::float2 {name}(uint b0, \
5360                                          uint b1, \
5361                                          uint b2, \
5362                                          uint b3, \
5363                                          uint b4, \
5364                                          uint b5, \
5365                                          uint b6, \
5366                                          uint b7) {{"
5367                )?;
5368                writeln!(
5369                    self.out,
5370                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5371                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5372                    back::INDENT
5373                )?;
5374                writeln!(self.out, "}}")?;
5375                Ok((name, 8, 2))
5376            }
5377            Float32x3 => {
5378                let name = self.namer.call("unpackFloat32x3");
5379                writeln!(
5380                    self.out,
5381                    "metal::float3 {name}(uint b0, \
5382                                          uint b1, \
5383                                          uint b2, \
5384                                          uint b3, \
5385                                          uint b4, \
5386                                          uint b5, \
5387                                          uint b6, \
5388                                          uint b7, \
5389                                          uint b8, \
5390                                          uint b9, \
5391                                          uint b10, \
5392                                          uint b11) {{"
5393                )?;
5394                writeln!(
5395                    self.out,
5396                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5397                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5398                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5399                    back::INDENT
5400                )?;
5401                writeln!(self.out, "}}")?;
5402                Ok((name, 12, 3))
5403            }
5404            Float32x4 => {
5405                let name = self.namer.call("unpackFloat32x4");
5406                writeln!(
5407                    self.out,
5408                    "metal::float4 {name}(uint b0, \
5409                                          uint b1, \
5410                                          uint b2, \
5411                                          uint b3, \
5412                                          uint b4, \
5413                                          uint b5, \
5414                                          uint b6, \
5415                                          uint b7, \
5416                                          uint b8, \
5417                                          uint b9, \
5418                                          uint b10, \
5419                                          uint b11, \
5420                                          uint b12, \
5421                                          uint b13, \
5422                                          uint b14, \
5423                                          uint b15) {{"
5424                )?;
5425                writeln!(
5426                    self.out,
5427                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5428                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5429                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5430                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5431                    back::INDENT
5432                )?;
5433                writeln!(self.out, "}}")?;
5434                Ok((name, 16, 4))
5435            }
5436            Uint32 => {
5437                let name = self.namer.call("unpackUint32");
5438                writeln!(
5439                    self.out,
5440                    "uint {name}(uint b0, \
5441                                 uint b1, \
5442                                 uint b2, \
5443                                 uint b3) {{"
5444                )?;
5445                writeln!(
5446                    self.out,
5447                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5448                    back::INDENT
5449                )?;
5450                writeln!(self.out, "}}")?;
5451                Ok((name, 4, 1))
5452            }
5453            Uint32x2 => {
5454                let name = self.namer.call("unpackUint32x2");
5455                writeln!(
5456                    self.out,
5457                    "uint2 {name}(uint b0, \
5458                                  uint b1, \
5459                                  uint b2, \
5460                                  uint b3, \
5461                                  uint b4, \
5462                                  uint b5, \
5463                                  uint b6, \
5464                                  uint b7) {{"
5465                )?;
5466                writeln!(
5467                    self.out,
5468                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5469                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5470                    back::INDENT
5471                )?;
5472                writeln!(self.out, "}}")?;
5473                Ok((name, 8, 2))
5474            }
5475            Uint32x3 => {
5476                let name = self.namer.call("unpackUint32x3");
5477                writeln!(
5478                    self.out,
5479                    "uint3 {name}(uint b0, \
5480                                  uint b1, \
5481                                  uint b2, \
5482                                  uint b3, \
5483                                  uint b4, \
5484                                  uint b5, \
5485                                  uint b6, \
5486                                  uint b7, \
5487                                  uint b8, \
5488                                  uint b9, \
5489                                  uint b10, \
5490                                  uint b11) {{"
5491                )?;
5492                writeln!(
5493                    self.out,
5494                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5495                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5496                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5497                    back::INDENT
5498                )?;
5499                writeln!(self.out, "}}")?;
5500                Ok((name, 12, 3))
5501            }
5502            Uint32x4 => {
5503                let name = self.namer.call("unpackUint32x4");
5504                writeln!(
5505                    self.out,
5506                    "{NAMESPACE}::uint4 {name}(uint b0, \
5507                                  uint b1, \
5508                                  uint b2, \
5509                                  uint b3, \
5510                                  uint b4, \
5511                                  uint b5, \
5512                                  uint b6, \
5513                                  uint b7, \
5514                                  uint b8, \
5515                                  uint b9, \
5516                                  uint b10, \
5517                                  uint b11, \
5518                                  uint b12, \
5519                                  uint b13, \
5520                                  uint b14, \
5521                                  uint b15) {{"
5522                )?;
5523                writeln!(
5524                    self.out,
5525                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5526                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5527                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5528                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5529                    back::INDENT
5530                )?;
5531                writeln!(self.out, "}}")?;
5532                Ok((name, 16, 4))
5533            }
5534            Sint32 => {
5535                let name = self.namer.call("unpackSint32");
5536                writeln!(
5537                    self.out,
5538                    "int {name}(uint b0, \
5539                                uint b1, \
5540                                uint b2, \
5541                                uint b3) {{"
5542                )?;
5543                writeln!(
5544                    self.out,
5545                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5546                    back::INDENT
5547                )?;
5548                writeln!(self.out, "}}")?;
5549                Ok((name, 4, 1))
5550            }
5551            Sint32x2 => {
5552                let name = self.namer.call("unpackSint32x2");
5553                writeln!(
5554                    self.out,
5555                    "metal::int2 {name}(uint b0, \
5556                                        uint b1, \
5557                                        uint b2, \
5558                                        uint b3, \
5559                                        uint b4, \
5560                                        uint b5, \
5561                                        uint b6, \
5562                                        uint b7) {{"
5563                )?;
5564                writeln!(
5565                    self.out,
5566                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5567                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5568                    back::INDENT
5569                )?;
5570                writeln!(self.out, "}}")?;
5571                Ok((name, 8, 2))
5572            }
5573            Sint32x3 => {
5574                let name = self.namer.call("unpackSint32x3");
5575                writeln!(
5576                    self.out,
5577                    "metal::int3 {name}(uint b0, \
5578                                        uint b1, \
5579                                        uint b2, \
5580                                        uint b3, \
5581                                        uint b4, \
5582                                        uint b5, \
5583                                        uint b6, \
5584                                        uint b7, \
5585                                        uint b8, \
5586                                        uint b9, \
5587                                        uint b10, \
5588                                        uint b11) {{"
5589                )?;
5590                writeln!(
5591                    self.out,
5592                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5593                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5594                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5595                    back::INDENT
5596                )?;
5597                writeln!(self.out, "}}")?;
5598                Ok((name, 12, 3))
5599            }
5600            Sint32x4 => {
5601                let name = self.namer.call("unpackSint32x4");
5602                writeln!(
5603                    self.out,
5604                    "metal::int4 {name}(uint b0, \
5605                                        uint b1, \
5606                                        uint b2, \
5607                                        uint b3, \
5608                                        uint b4, \
5609                                        uint b5, \
5610                                        uint b6, \
5611                                        uint b7, \
5612                                        uint b8, \
5613                                        uint b9, \
5614                                        uint b10, \
5615                                        uint b11, \
5616                                        uint b12, \
5617                                        uint b13, \
5618                                        uint b14, \
5619                                        uint b15) {{"
5620                )?;
5621                writeln!(
5622                    self.out,
5623                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5624                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5625                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5626                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5627                    back::INDENT
5628                )?;
5629                writeln!(self.out, "}}")?;
5630                Ok((name, 16, 4))
5631            }
5632            Unorm10_10_10_2 => {
5633                let name = self.namer.call("unpackUnorm10_10_10_2");
5634                writeln!(
5635                    self.out,
5636                    "metal::float4 {name}(uint b0, \
5637                                          uint b1, \
5638                                          uint b2, \
5639                                          uint b3) {{"
5640                )?;
5641                writeln!(
5642                    self.out,
5643                    // The following is correct for RGBA packing, but our format seems to
5644                    // match ABGR, which can be fed into the Metal builtin function
5645                    // unpack_unorm10a2_to_float.
5646                    /*
5647                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5648                       uint r = (v & 0xFFC00000) >> 22; \
5649                       uint g = (v & 0x003FF000) >> 12; \
5650                       uint b = (v & 0x00000FFC) >> 2; \
5651                       uint a = (v & 0x00000003); \
5652                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5653                    */
5654                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5655                    back::INDENT
5656                )?;
5657                writeln!(self.out, "}}")?;
5658                Ok((name, 4, 4))
5659            }
5660            Unorm8x4Bgra => {
5661                let name = self.namer.call("unpackUnorm8x4Bgra");
5662                writeln!(
5663                    self.out,
5664                    "metal::float4 {name}(metal::uchar b0, \
5665                                          metal::uchar b1, \
5666                                          metal::uchar b2, \
5667                                          metal::uchar b3) {{"
5668                )?;
5669                writeln!(
5670                    self.out,
5671                    "{}return metal::float4(float(b2) / 255.0f, \
5672                                            float(b1) / 255.0f, \
5673                                            float(b0) / 255.0f, \
5674                                            float(b3) / 255.0f);",
5675                    back::INDENT
5676                )?;
5677                writeln!(self.out, "}}")?;
5678                Ok((name, 4, 4))
5679            }
5680        }
5681    }
5682
5683    fn write_wrapped_unary_op(
5684        &mut self,
5685        module: &crate::Module,
5686        func_ctx: &back::FunctionCtx,
5687        op: crate::UnaryOperator,
5688        operand: Handle<crate::Expression>,
5689    ) -> BackendResult {
5690        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5691        match op {
5692            // Negating the TYPE_MIN of a two's complement signed integer
5693            // type causes overflow, which is undefined behaviour in MSL. To
5694            // avoid this we bitcast the value to unsigned and negate it,
5695            // then bitcast back to signed.
5696            // This adheres to the WGSL spec in that the negative of the
5697            // type's minimum value should equal to the minimum value.
5698            crate::UnaryOperator::Negate
5699                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5700            {
5701                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5702                    return Ok(());
5703                };
5704                let wrapped = WrappedFunction::UnaryOp {
5705                    op,
5706                    ty: (vector_size, scalar),
5707                };
5708                if !self.wrapped_functions.insert(wrapped) {
5709                    return Ok(());
5710                }
5711
5712                let unsigned_scalar = crate::Scalar {
5713                    kind: crate::ScalarKind::Uint,
5714                    ..scalar
5715                };
5716                let mut type_name = String::new();
5717                let mut unsigned_type_name = String::new();
5718                match vector_size {
5719                    None => {
5720                        put_numeric_type(&mut type_name, scalar, &[])?;
5721                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5722                    }
5723                    Some(size) => {
5724                        put_numeric_type(&mut type_name, scalar, &[size])?;
5725                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5726                    }
5727                };
5728
5729                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5730                let level = back::Level(1);
5731                writeln!(
5732                    self.out,
5733                    "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5734                )?;
5735                writeln!(self.out, "}}")?;
5736                writeln!(self.out)?;
5737            }
5738            _ => {}
5739        }
5740        Ok(())
5741    }
5742
5743    fn write_wrapped_binary_op(
5744        &mut self,
5745        module: &crate::Module,
5746        func_ctx: &back::FunctionCtx,
5747        expr: Handle<crate::Expression>,
5748        op: crate::BinaryOperator,
5749        left: Handle<crate::Expression>,
5750        right: Handle<crate::Expression>,
5751    ) -> BackendResult {
5752        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5753        let left_ty = func_ctx.resolve_type(left, &module.types);
5754        let right_ty = func_ctx.resolve_type(right, &module.types);
5755        match (op, expr_ty.scalar_kind()) {
5756            // Signed integer division of TYPE_MIN / -1, or signed or
5757            // unsigned division by zero, gives an unspecified value in MSL.
5758            // We override the divisor to 1 in these cases.
5759            // This adheres to the WGSL spec in that:
5760            // * TYPE_MIN / -1 == TYPE_MIN
5761            // * x / 0 == x
5762            (
5763                crate::BinaryOperator::Divide,
5764                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5765            ) => {
5766                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5767                    return Ok(());
5768                };
5769                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5770                    return Ok(());
5771                };
5772                let wrapped = WrappedFunction::BinaryOp {
5773                    op,
5774                    left_ty: left_wrapped_ty,
5775                    right_ty: right_wrapped_ty,
5776                };
5777                if !self.wrapped_functions.insert(wrapped) {
5778                    return Ok(());
5779                }
5780
5781                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5782                    return Ok(());
5783                };
5784                let mut type_name = String::new();
5785                match vector_size {
5786                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5787                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5788                };
5789                writeln!(
5790                    self.out,
5791                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5792                )?;
5793                let level = back::Level(1);
5794                match scalar.kind {
5795                    crate::ScalarKind::Sint => {
5796                        let min_val = match scalar.width {
5797                            4 => crate::Literal::I32(i32::MIN),
5798                            8 => crate::Literal::I64(i64::MIN),
5799                            _ => {
5800                                return Err(Error::GenericValidation(format!(
5801                                    "Unexpected width for scalar {scalar:?}"
5802                                )));
5803                            }
5804                        };
5805                        write!(
5806                            self.out,
5807                            "{level}return lhs / metal::select(rhs, 1, (lhs == "
5808                        )?;
5809                        self.put_literal(min_val)?;
5810                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5811                    }
5812                    crate::ScalarKind::Uint => writeln!(
5813                        self.out,
5814                        "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5815                    )?,
5816                    _ => unreachable!(),
5817                }
5818                writeln!(self.out, "}}")?;
5819                writeln!(self.out)?;
5820            }
5821            // Integer modulo where one or both operands are negative, or the
5822            // divisor is zero, is undefined behaviour in MSL. To avoid this
5823            // we use the following equation:
5824            //
5825            // dividend - (dividend / divisor) * divisor
5826            //
5827            // overriding the divisor to 1 if either it is 0, or it is -1
5828            // and the dividend is TYPE_MIN.
5829            //
5830            // This adheres to the WGSL spec in that:
5831            // * TYPE_MIN % -1 == 0
5832            // * x % 0 == 0
5833            (
5834                crate::BinaryOperator::Modulo,
5835                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5836            ) => {
5837                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5838                    return Ok(());
5839                };
5840                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5841                else {
5842                    return Ok(());
5843                };
5844                let wrapped = WrappedFunction::BinaryOp {
5845                    op,
5846                    left_ty: left_wrapped_ty,
5847                    right_ty: (right_vector_size, right_scalar),
5848                };
5849                if !self.wrapped_functions.insert(wrapped) {
5850                    return Ok(());
5851                }
5852
5853                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5854                    return Ok(());
5855                };
5856                let mut type_name = String::new();
5857                match vector_size {
5858                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5859                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5860                };
5861                let mut rhs_type_name = String::new();
5862                match right_vector_size {
5863                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5864                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5865                };
5866
5867                writeln!(
5868                    self.out,
5869                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5870                )?;
5871                let level = back::Level(1);
5872                match scalar.kind {
5873                    crate::ScalarKind::Sint => {
5874                        let min_val = match scalar.width {
5875                            4 => crate::Literal::I32(i32::MIN),
5876                            8 => crate::Literal::I64(i64::MIN),
5877                            _ => {
5878                                return Err(Error::GenericValidation(format!(
5879                                    "Unexpected width for scalar {scalar:?}"
5880                                )));
5881                            }
5882                        };
5883                        write!(
5884                            self.out,
5885                            "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5886                        )?;
5887                        self.put_literal(min_val)?;
5888                        writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5889                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5890                    }
5891                    crate::ScalarKind::Uint => writeln!(
5892                        self.out,
5893                        "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5894                    )?,
5895                    _ => unreachable!(),
5896                }
5897                writeln!(self.out, "}}")?;
5898                writeln!(self.out)?;
5899            }
5900            _ => {}
5901        }
5902        Ok(())
5903    }
5904
5905    /// Build the mangled helper name for integer vector dot products.
5906    ///
5907    /// `scalar` must be a concrete integer scalar type.
5908    ///
5909    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5910    fn get_dot_wrapper_function_helper_name(
5911        &self,
5912        scalar: crate::Scalar,
5913        size: crate::VectorSize,
5914    ) -> String {
5915        // Check for consistency with [`super::keywords::RESERVED_SET`]
5916        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5917
5918        let type_name = scalar.to_msl_name();
5919        let size_suffix = common::vector_size_str(size);
5920        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5921    }
5922
5923    #[allow(clippy::too_many_arguments)]
5924    fn write_wrapped_math_function(
5925        &mut self,
5926        module: &crate::Module,
5927        func_ctx: &back::FunctionCtx,
5928        fun: crate::MathFunction,
5929        arg: Handle<crate::Expression>,
5930        _arg1: Option<Handle<crate::Expression>>,
5931        _arg2: Option<Handle<crate::Expression>>,
5932        _arg3: Option<Handle<crate::Expression>>,
5933    ) -> BackendResult {
5934        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5935        match fun {
5936            // Taking the absolute value of the TYPE_MIN of a two's
5937            // complement signed integer type causes overflow, which is
5938            // undefined behaviour in MSL. To avoid this, when the value is
5939            // negative we bitcast the value to unsigned and negate it, then
5940            // bitcast back to signed.
5941            // This adheres to the WGSL spec in that the absolute of the
5942            // type's minimum value should equal to the minimum value.
5943            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5944                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5945                    return Ok(());
5946                };
5947                let wrapped = WrappedFunction::Math {
5948                    fun,
5949                    arg_ty: (vector_size, scalar),
5950                };
5951                if !self.wrapped_functions.insert(wrapped) {
5952                    return Ok(());
5953                }
5954
5955                let unsigned_scalar = crate::Scalar {
5956                    kind: crate::ScalarKind::Uint,
5957                    ..scalar
5958                };
5959                let mut type_name = String::new();
5960                let mut unsigned_type_name = String::new();
5961                match vector_size {
5962                    None => {
5963                        put_numeric_type(&mut type_name, scalar, &[])?;
5964                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5965                    }
5966                    Some(size) => {
5967                        put_numeric_type(&mut type_name, scalar, &[size])?;
5968                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5969                    }
5970                };
5971
5972                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5973                let level = back::Level(1);
5974                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5975                writeln!(self.out, "}}")?;
5976                writeln!(self.out)?;
5977            }
5978
5979            crate::MathFunction::Dot => match *arg_ty {
5980                crate::TypeInner::Vector { size, scalar }
5981                    if matches!(
5982                        scalar.kind,
5983                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
5984                    ) =>
5985                {
5986                    // De-duplicate per (fun, arg type) like other wrapped math functions
5987                    let wrapped = WrappedFunction::Math {
5988                        fun,
5989                        arg_ty: (Some(size), scalar),
5990                    };
5991                    if !self.wrapped_functions.insert(wrapped) {
5992                        return Ok(());
5993                    }
5994
5995                    let mut vec_ty = String::new();
5996                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
5997                    let mut ret_ty = String::new();
5998                    put_numeric_type(&mut ret_ty, scalar, &[])?;
5999
6000                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6001
6002                    // Emit function signature and body using put_dot_product for the expression
6003                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6004                    let level = back::Level(1);
6005                    write!(self.out, "{level}return ")?;
6006                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6007                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6008                        Ok(())
6009                    })?;
6010                    writeln!(self.out, ";")?;
6011                    writeln!(self.out, "}}")?;
6012                    writeln!(self.out)?;
6013                }
6014                _ => {}
6015            },
6016
6017            _ => {}
6018        }
6019        Ok(())
6020    }
6021
6022    fn write_wrapped_cast(
6023        &mut self,
6024        module: &crate::Module,
6025        func_ctx: &back::FunctionCtx,
6026        expr: Handle<crate::Expression>,
6027        kind: crate::ScalarKind,
6028        convert: Option<crate::Bytes>,
6029    ) -> BackendResult {
6030        // Avoid undefined behaviour when casting from a float to integer
6031        // when the value is out of range for the target type. Additionally
6032        // ensure we clamp to the correct value as per the WGSL spec.
6033        //
6034        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
6035        // * If X is exactly representable in the target type T, then the
6036        //   result is that value.
6037        // * Otherwise, the result is the value in T closest to
6038        //   truncate(X) and also exactly representable in the original
6039        //   floating point type.
6040        let src_ty = func_ctx.resolve_type(expr, &module.types);
6041        let Some(width) = convert else {
6042            return Ok(());
6043        };
6044        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6045            return Ok(());
6046        };
6047        let dst_scalar = crate::Scalar { kind, width };
6048        if src_scalar.kind != crate::ScalarKind::Float
6049            || (dst_scalar.kind != crate::ScalarKind::Sint
6050                && dst_scalar.kind != crate::ScalarKind::Uint)
6051        {
6052            return Ok(());
6053        }
6054        let wrapped = WrappedFunction::Cast {
6055            src_scalar,
6056            vector_size,
6057            dst_scalar,
6058        };
6059        if !self.wrapped_functions.insert(wrapped) {
6060            return Ok(());
6061        }
6062        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6063
6064        let mut src_type_name = String::new();
6065        match vector_size {
6066            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6067            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6068        };
6069        let mut dst_type_name = String::new();
6070        match vector_size {
6071            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6072            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6073        };
6074        let fun_name = match dst_scalar {
6075            crate::Scalar::I32 => F2I32_FUNCTION,
6076            crate::Scalar::U32 => F2U32_FUNCTION,
6077            crate::Scalar::I64 => F2I64_FUNCTION,
6078            crate::Scalar::U64 => F2U64_FUNCTION,
6079            _ => unreachable!(),
6080        };
6081
6082        writeln!(
6083            self.out,
6084            "{dst_type_name} {fun_name}({src_type_name} value) {{"
6085        )?;
6086        let level = back::Level(1);
6087        write!(
6088            self.out,
6089            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6090        )?;
6091        self.put_literal(min)?;
6092        write!(self.out, ", ")?;
6093        self.put_literal(max)?;
6094        writeln!(self.out, "));")?;
6095        writeln!(self.out, "}}")?;
6096        writeln!(self.out)?;
6097        Ok(())
6098    }
6099
6100    /// Helper function used by [`Self::write_wrapped_image_load`] and
6101    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
6102    /// conversion code for external textures. Expects the preceding code to
6103    /// declare the Y component as a `float` variable of name `y`, the UV
6104    /// components as a `float2` variable of name `uv`, and the external
6105    /// texture params as a variable of name `params`. The emitted code will
6106    /// return the result.
6107    fn write_convert_yuv_to_rgb_and_return(
6108        &mut self,
6109        level: back::Level,
6110        y: &str,
6111        uv: &str,
6112        params: &str,
6113    ) -> BackendResult {
6114        let l1 = level;
6115        let l2 = l1.next();
6116
6117        // Convert from YUV to non-linear RGB in the source color space.
6118        writeln!(
6119            self.out,
6120            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6121        )?;
6122
6123        // Apply the inverse of the source transfer function to convert to
6124        // linear RGB in the source color space.
6125        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6126        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6127        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6128        writeln!(
6129            self.out,
6130            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6131        )?;
6132
6133        // Multiply by the gamut conversion matrix to convert to linear RGB in
6134        // the destination color space.
6135        writeln!(
6136            self.out,
6137            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6138        )?;
6139
6140        // Finally, apply the dest transfer function to convert to non-linear
6141        // RGB in the destination color space, and return the result.
6142        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6143        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6144        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6145        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6146
6147        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6148        Ok(())
6149    }
6150
6151    #[allow(clippy::too_many_arguments)]
6152    fn write_wrapped_image_load(
6153        &mut self,
6154        module: &crate::Module,
6155        func_ctx: &back::FunctionCtx,
6156        image: Handle<crate::Expression>,
6157        _coordinate: Handle<crate::Expression>,
6158        _array_index: Option<Handle<crate::Expression>>,
6159        _sample: Option<Handle<crate::Expression>>,
6160        _level: Option<Handle<crate::Expression>>,
6161    ) -> BackendResult {
6162        // We currently only need to wrap image loads for external textures
6163        let class = match *func_ctx.resolve_type(image, &module.types) {
6164            crate::TypeInner::Image { class, .. } => class,
6165            _ => unreachable!(),
6166        };
6167        if class != crate::ImageClass::External {
6168            return Ok(());
6169        }
6170        let wrapped = WrappedFunction::ImageLoad { class };
6171        if !self.wrapped_functions.insert(wrapped) {
6172            return Ok(());
6173        }
6174
6175        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6176        let l1 = back::Level(1);
6177        let l2 = l1.next();
6178        let l3 = l2.next();
6179        writeln!(
6180            self.out,
6181            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6182        )?;
6183        // Clamp coords to provided size of external texture to prevent OOB
6184        // read. If params.size is zero then clamp to the actual size of the
6185        // texture.
6186        writeln!(
6187            self.out,
6188            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6189        )?;
6190        writeln!(
6191            self.out,
6192            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6193        )?;
6194
6195        // Apply load transformation
6196        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6197        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6198        // For single plane, simply read from plane0
6199        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6200        writeln!(self.out, "{l1}}} else {{")?;
6201
6202        // Chroma planes may be subsampled so we must scale the coords accordingly.
6203        writeln!(
6204            self.out,
6205            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6206        )?;
6207        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6208
6209        // For multi-plane, read the Y value from plane 0
6210        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6211
6212        writeln!(self.out, "{l2}float2 uv;")?;
6213        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6214        // For 2 planes, read UV from interleaved plane 1
6215        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6216        writeln!(self.out, "{l2}}} else {{")?;
6217        // For 3 planes, read U and V from planes 1 and 2 respectively
6218        writeln!(
6219            self.out,
6220            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6221        )?;
6222        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6223        writeln!(
6224            self.out,
6225            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6226        )?;
6227        writeln!(self.out, "{l2}}}")?;
6228
6229        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6230
6231        writeln!(self.out, "{l1}}}")?;
6232        writeln!(self.out, "}}")?;
6233        writeln!(self.out)?;
6234        Ok(())
6235    }
6236
6237    #[allow(clippy::too_many_arguments)]
6238    fn write_wrapped_image_sample(
6239        &mut self,
6240        module: &crate::Module,
6241        func_ctx: &back::FunctionCtx,
6242        image: Handle<crate::Expression>,
6243        _sampler: Handle<crate::Expression>,
6244        _gather: Option<crate::SwizzleComponent>,
6245        _coordinate: Handle<crate::Expression>,
6246        _array_index: Option<Handle<crate::Expression>>,
6247        _offset: Option<Handle<crate::Expression>>,
6248        _level: crate::SampleLevel,
6249        _depth_ref: Option<Handle<crate::Expression>>,
6250        clamp_to_edge: bool,
6251    ) -> BackendResult {
6252        // We currently only need to wrap textureSampleBaseClampToEdge, for
6253        // both sampled and external textures.
6254        if !clamp_to_edge {
6255            return Ok(());
6256        }
6257        let class = match *func_ctx.resolve_type(image, &module.types) {
6258            crate::TypeInner::Image { class, .. } => class,
6259            _ => unreachable!(),
6260        };
6261        let wrapped = WrappedFunction::ImageSample {
6262            class,
6263            clamp_to_edge: true,
6264        };
6265        if !self.wrapped_functions.insert(wrapped) {
6266            return Ok(());
6267        }
6268        match class {
6269            crate::ImageClass::External => {
6270                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6271                let l1 = back::Level(1);
6272                let l2 = l1.next();
6273                let l3 = l2.next();
6274                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6275                writeln!(
6276                    self.out,
6277                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6278                )?;
6279
6280                // Calculate the sample bounds. The purported size of the texture
6281                // (params.size) is irrelevant here as we are dealing with normalized
6282                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6283                // apply the sample transformation to that, also bearing in mind that it
6284                // may contain a flip on either axis. We calculate and adjust for the
6285                // half-texel separately for each plane as it depends on the actual
6286                // texture size which may vary between planes.
6287                writeln!(
6288                    self.out,
6289                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6290                )?;
6291                writeln!(
6292                    self.out,
6293                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6294                )?;
6295                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6296                writeln!(
6297                    self.out,
6298                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6299                )?;
6300                writeln!(
6301                    self.out,
6302                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6303                )?;
6304                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6305                // For single plane, simply sample from plane0
6306                writeln!(
6307                    self.out,
6308                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6309                )?;
6310                writeln!(self.out, "{l1}}} else {{")?;
6311                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6312                writeln!(
6313                    self.out,
6314                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6315                )?;
6316                writeln!(
6317                    self.out,
6318                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6319                )?;
6320
6321                // For multi-plane, sample the Y value from plane 0
6322                writeln!(
6323                    self.out,
6324                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6325                )?;
6326                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6327                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6328                // For 2 planes, sample UV from interleaved plane 1
6329                writeln!(
6330                    self.out,
6331                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6332                )?;
6333                writeln!(self.out, "{l2}}} else {{")?;
6334                // For 3 planes, sample U and V from planes 1 and 2 respectively
6335                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6336                writeln!(
6337                    self.out,
6338                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6339                )?;
6340                writeln!(
6341                    self.out,
6342                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6343                )?;
6344                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6345                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6346                writeln!(self.out, "{l2}}}")?;
6347
6348                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6349
6350                writeln!(self.out, "{l1}}}")?;
6351                writeln!(self.out, "}}")?;
6352                writeln!(self.out)?;
6353            }
6354            _ => {
6355                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6356                let l1 = back::Level(1);
6357                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6358                writeln!(
6359                    self.out,
6360                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6361                )?;
6362                writeln!(self.out, "}}")?;
6363                writeln!(self.out)?;
6364            }
6365        }
6366        Ok(())
6367    }
6368
6369    fn write_wrapped_image_query(
6370        &mut self,
6371        module: &crate::Module,
6372        func_ctx: &back::FunctionCtx,
6373        image: Handle<crate::Expression>,
6374        query: crate::ImageQuery,
6375    ) -> BackendResult {
6376        // We currently only need to wrap size image queries for external textures
6377        if !matches!(query, crate::ImageQuery::Size { .. }) {
6378            return Ok(());
6379        }
6380        let class = match *func_ctx.resolve_type(image, &module.types) {
6381            crate::TypeInner::Image { class, .. } => class,
6382            _ => unreachable!(),
6383        };
6384        if class != crate::ImageClass::External {
6385            return Ok(());
6386        }
6387        let wrapped = WrappedFunction::ImageQuerySize { class };
6388        if !self.wrapped_functions.insert(wrapped) {
6389            return Ok(());
6390        }
6391        writeln!(
6392            self.out,
6393            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6394        )?;
6395        let l1 = back::Level(1);
6396        let l2 = l1.next();
6397        writeln!(
6398            self.out,
6399            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6400        )?;
6401        writeln!(self.out, "{l2}return tex.params.size;")?;
6402        writeln!(self.out, "{l1}}} else {{")?;
6403        // params.size == (0, 0) indicates to query and return plane 0's actual size
6404        writeln!(
6405            self.out,
6406            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6407        )?;
6408        writeln!(self.out, "{l1}}}")?;
6409        writeln!(self.out, "}}")?;
6410        writeln!(self.out)?;
6411        Ok(())
6412    }
6413
6414    fn write_wrapped_cooperative_load(
6415        &mut self,
6416        module: &crate::Module,
6417        func_ctx: &back::FunctionCtx,
6418        columns: crate::CooperativeSize,
6419        rows: crate::CooperativeSize,
6420        pointer: Handle<crate::Expression>,
6421    ) -> BackendResult {
6422        let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6423        let space = ptr_ty.pointer_space().unwrap();
6424        let space_name = space.to_msl_name().unwrap_or_default();
6425        let scalar = ptr_ty
6426            .pointer_base_type()
6427            .unwrap()
6428            .inner_with(&module.types)
6429            .scalar()
6430            .unwrap();
6431        let wrapped = WrappedFunction::CooperativeLoad {
6432            space_name,
6433            columns,
6434            rows,
6435            scalar,
6436        };
6437        if !self.wrapped_functions.insert(wrapped) {
6438            return Ok(());
6439        }
6440        let scalar_name = scalar.to_msl_name();
6441        writeln!(
6442            self.out,
6443            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6444            columns as u32, rows as u32,
6445        )?;
6446        let l1 = back::Level(1);
6447        writeln!(
6448            self.out,
6449            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6450            columns as u32, rows as u32
6451        )?;
6452        let matrix_origin = "0";
6453        writeln!(
6454            self.out,
6455            "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6456        )?;
6457        writeln!(self.out, "{l1}return m;")?;
6458        writeln!(self.out, "}}")?;
6459        writeln!(self.out)?;
6460        Ok(())
6461    }
6462
6463    fn write_wrapped_cooperative_multiply_add(
6464        &mut self,
6465        module: &crate::Module,
6466        func_ctx: &back::FunctionCtx,
6467        space: crate::AddressSpace,
6468        a: Handle<crate::Expression>,
6469        b: Handle<crate::Expression>,
6470    ) -> BackendResult {
6471        let space_name = space.to_msl_name().unwrap_or_default();
6472        let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6473            crate::TypeInner::CooperativeMatrix {
6474                columns,
6475                rows,
6476                scalar,
6477                ..
6478            } => (columns, rows, scalar),
6479            _ => unreachable!(),
6480        };
6481        let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6482            crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6483            _ => unreachable!(),
6484        };
6485        let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6486            space_name,
6487            columns: b_c,
6488            rows: a_r,
6489            intermediate: a_c,
6490            scalar,
6491        };
6492        if !self.wrapped_functions.insert(wrapped) {
6493            return Ok(());
6494        }
6495        let scalar_name = scalar.to_msl_name();
6496        writeln!(
6497            self.out,
6498            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6499            b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6500        )?;
6501        let l1 = back::Level(1);
6502        writeln!(
6503            self.out,
6504            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6505            b_c as u32, a_r as u32
6506        )?;
6507        writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6508        writeln!(self.out, "{l1}return d;")?;
6509        writeln!(self.out, "}}")?;
6510        writeln!(self.out)?;
6511        Ok(())
6512    }
6513
6514    pub(super) fn write_wrapped_functions(
6515        &mut self,
6516        module: &crate::Module,
6517        func_ctx: &back::FunctionCtx,
6518    ) -> BackendResult {
6519        for (expr_handle, expr) in func_ctx.expressions.iter() {
6520            match *expr {
6521                crate::Expression::Unary { op, expr: operand } => {
6522                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6523                }
6524                crate::Expression::Binary { op, left, right } => {
6525                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6526                }
6527                crate::Expression::Math {
6528                    fun,
6529                    arg,
6530                    arg1,
6531                    arg2,
6532                    arg3,
6533                } => {
6534                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6535                }
6536                crate::Expression::As {
6537                    expr,
6538                    kind,
6539                    convert,
6540                } => {
6541                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6542                }
6543                crate::Expression::ImageLoad {
6544                    image,
6545                    coordinate,
6546                    array_index,
6547                    sample,
6548                    level,
6549                } => {
6550                    self.write_wrapped_image_load(
6551                        module,
6552                        func_ctx,
6553                        image,
6554                        coordinate,
6555                        array_index,
6556                        sample,
6557                        level,
6558                    )?;
6559                }
6560                crate::Expression::ImageSample {
6561                    image,
6562                    sampler,
6563                    gather,
6564                    coordinate,
6565                    array_index,
6566                    offset,
6567                    level,
6568                    depth_ref,
6569                    clamp_to_edge,
6570                } => {
6571                    self.write_wrapped_image_sample(
6572                        module,
6573                        func_ctx,
6574                        image,
6575                        sampler,
6576                        gather,
6577                        coordinate,
6578                        array_index,
6579                        offset,
6580                        level,
6581                        depth_ref,
6582                        clamp_to_edge,
6583                    )?;
6584                }
6585                crate::Expression::ImageQuery { image, query } => {
6586                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6587                }
6588                crate::Expression::CooperativeLoad {
6589                    columns,
6590                    rows,
6591                    role: _,
6592                    ref data,
6593                } => {
6594                    self.write_wrapped_cooperative_load(
6595                        module,
6596                        func_ctx,
6597                        columns,
6598                        rows,
6599                        data.pointer,
6600                    )?;
6601                }
6602                crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6603                    let space = crate::AddressSpace::Private;
6604                    self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6605                }
6606                _ => {}
6607            }
6608        }
6609
6610        Ok(())
6611    }
6612
6613    // Returns the array of mapped entry point names.
6614    fn write_functions(
6615        &mut self,
6616        module: &crate::Module,
6617        mod_info: &valid::ModuleInfo,
6618        options: &Options,
6619        pipeline_options: &PipelineOptions,
6620    ) -> Result<TranslationInfo, Error> {
6621        use back::msl::VertexFormat;
6622
6623        // Define structs to hold resolved/generated data for vertex buffers and
6624        // their attributes.
6625        struct AttributeMappingResolved {
6626            ty_name: String,
6627            dimension: u32,
6628            ty_is_int: bool,
6629            name: String,
6630        }
6631        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6632
6633        struct VertexBufferMappingResolved<'a> {
6634            id: u32,
6635            stride: u32,
6636            step_mode: back::msl::VertexBufferStepMode,
6637            ty_name: String,
6638            param_name: String,
6639            elem_name: String,
6640            attributes: &'a Vec<back::msl::AttributeMapping>,
6641        }
6642        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6643
6644        // Define a struct to hold a named reference to a byte-unpacking function.
6645        struct UnpackingFunction {
6646            name: String,
6647            byte_count: u32,
6648            dimension: u32,
6649        }
6650        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6651
6652        // Check if we are attempting vertex pulling. If we are, generate some
6653        // names we'll need, and iterate the vertex buffer mappings to output
6654        // all the conversion functions we'll need to unpack the attribute data.
6655        // We can re-use these names for all entry points that need them, since
6656        // those entry points also use self.namer.
6657        let mut needs_vertex_id = false;
6658        let v_id = self.namer.call("v_id");
6659
6660        let mut needs_instance_id = false;
6661        let i_id = self.namer.call("i_id");
6662        if pipeline_options.vertex_pulling_transform {
6663            for vbm in &pipeline_options.vertex_buffer_mappings {
6664                let buffer_id = vbm.id;
6665                let buffer_stride = vbm.stride;
6666
6667                assert!(
6668                    buffer_stride > 0,
6669                    "Vertex pulling requires a non-zero buffer stride."
6670                );
6671
6672                match vbm.step_mode {
6673                    back::msl::VertexBufferStepMode::Constant => {}
6674                    back::msl::VertexBufferStepMode::ByVertex => {
6675                        needs_vertex_id = true;
6676                    }
6677                    back::msl::VertexBufferStepMode::ByInstance => {
6678                        needs_instance_id = true;
6679                    }
6680                }
6681
6682                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6683                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6684                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6685
6686                vbm_resolved.push(VertexBufferMappingResolved {
6687                    id: buffer_id,
6688                    stride: buffer_stride,
6689                    step_mode: vbm.step_mode,
6690                    ty_name: buffer_ty,
6691                    param_name: buffer_param,
6692                    elem_name: buffer_elem,
6693                    attributes: &vbm.attributes,
6694                });
6695
6696                // Iterate the attributes and generate needed unpacking functions.
6697                for attribute in &vbm.attributes {
6698                    if unpacking_functions.contains_key(&attribute.format) {
6699                        continue;
6700                    }
6701                    let (name, byte_count, dimension) =
6702                        match self.write_unpacking_function(attribute.format) {
6703                            Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6704                            _ => {
6705                                continue;
6706                            }
6707                        };
6708                    unpacking_functions.insert(
6709                        attribute.format,
6710                        UnpackingFunction {
6711                            name,
6712                            byte_count,
6713                            dimension,
6714                        },
6715                    );
6716                }
6717            }
6718        }
6719
6720        let mut pass_through_globals = Vec::new();
6721        for (fun_handle, fun) in module.functions.iter() {
6722            log::trace!(
6723                "function {:?}, handle {:?}",
6724                fun.name.as_deref().unwrap_or("(anonymous)"),
6725                fun_handle
6726            );
6727
6728            let ctx = back::FunctionCtx {
6729                ty: back::FunctionType::Function(fun_handle),
6730                info: &mod_info[fun_handle],
6731                expressions: &fun.expressions,
6732                named_expressions: &fun.named_expressions,
6733            };
6734
6735            writeln!(self.out)?;
6736            self.write_wrapped_functions(module, &ctx)?;
6737
6738            let fun_info = &mod_info[fun_handle];
6739            pass_through_globals.clear();
6740            let mut needs_buffer_sizes = false;
6741            for (handle, var) in module.global_variables.iter() {
6742                if !fun_info[handle].is_empty() {
6743                    if var.space.needs_pass_through() {
6744                        pass_through_globals.push(handle);
6745                    }
6746                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6747                }
6748            }
6749
6750            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6751            match fun.result {
6752                Some(ref result) => {
6753                    let ty_name = TypeContext {
6754                        handle: result.ty,
6755                        gctx: module.to_ctx(),
6756                        names: &self.names,
6757                        access: crate::StorageAccess::empty(),
6758                        first_time: false,
6759                    };
6760                    write!(self.out, "{ty_name}")?;
6761                }
6762                None => {
6763                    write!(self.out, "void")?;
6764                }
6765            }
6766            writeln!(self.out, " {fun_name}(")?;
6767
6768            for (index, arg) in fun.arguments.iter().enumerate() {
6769                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6770                let param_type_name = TypeContext {
6771                    handle: arg.ty,
6772                    gctx: module.to_ctx(),
6773                    names: &self.names,
6774                    access: crate::StorageAccess::empty(),
6775                    first_time: false,
6776                };
6777                let separator = separate(
6778                    !pass_through_globals.is_empty()
6779                        || index + 1 != fun.arguments.len()
6780                        || needs_buffer_sizes,
6781                );
6782                writeln!(
6783                    self.out,
6784                    "{}{} {}{}",
6785                    back::INDENT,
6786                    param_type_name,
6787                    name,
6788                    separator
6789                )?;
6790            }
6791            for (index, &handle) in pass_through_globals.iter().enumerate() {
6792                let tyvar = TypedGlobalVariable {
6793                    module,
6794                    names: &self.names,
6795                    handle,
6796                    usage: fun_info[handle],
6797                    reference: true,
6798                };
6799                let separator =
6800                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6801                write!(self.out, "{}", back::INDENT)?;
6802                tyvar.try_fmt(&mut self.out)?;
6803                writeln!(self.out, "{separator}")?;
6804            }
6805
6806            if needs_buffer_sizes {
6807                writeln!(
6808                    self.out,
6809                    "{}constant _mslBufferSizes& _buffer_sizes",
6810                    back::INDENT
6811                )?;
6812            }
6813
6814            writeln!(self.out, ") {{")?;
6815
6816            let guarded_indices =
6817                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6818
6819            let context = StatementContext {
6820                expression: ExpressionContext {
6821                    function: fun,
6822                    origin: FunctionOrigin::Handle(fun_handle),
6823                    info: fun_info,
6824                    lang_version: options.lang_version,
6825                    policies: options.bounds_check_policies,
6826                    guarded_indices,
6827                    module,
6828                    mod_info,
6829                    pipeline_options,
6830                    force_loop_bounding: options.force_loop_bounding,
6831                },
6832                result_struct: None,
6833            };
6834
6835            self.put_locals(&context.expression)?;
6836            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6837            self.put_block(back::Level(1), &fun.body, &context)?;
6838            writeln!(self.out, "}}")?;
6839            self.named_expressions.clear();
6840        }
6841
6842        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6843            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6844
6845        let mut info = TranslationInfo {
6846            entry_point_names: Vec::with_capacity(ep_range.len()),
6847        };
6848
6849        for ep_index in ep_range {
6850            let ep = &module.entry_points[ep_index];
6851            let fun = &ep.function;
6852            let fun_info = mod_info.get_entry_point(ep_index);
6853            let mut ep_error = None;
6854
6855            // For vertex_id and instance_id arguments, presume that we'll
6856            // use our generated names, but switch to the name of an
6857            // existing @builtin param, if we find one.
6858            let mut v_existing_id = None;
6859            let mut i_existing_id = None;
6860
6861            log::trace!(
6862                "entry point {:?}, index {:?}",
6863                fun.name.as_deref().unwrap_or("(anonymous)"),
6864                ep_index
6865            );
6866
6867            let ctx = back::FunctionCtx {
6868                ty: back::FunctionType::EntryPoint(ep_index as u16),
6869                info: fun_info,
6870                expressions: &fun.expressions,
6871                named_expressions: &fun.named_expressions,
6872            };
6873
6874            self.write_wrapped_functions(module, &ctx)?;
6875
6876            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6877                crate::ShaderStage::Vertex => (
6878                    "vertex",
6879                    LocationMode::VertexInput,
6880                    LocationMode::VertexOutput,
6881                    true,
6882                ),
6883                crate::ShaderStage::Fragment => (
6884                    "fragment",
6885                    LocationMode::FragmentInput,
6886                    LocationMode::FragmentOutput,
6887                    false,
6888                ),
6889                crate::ShaderStage::Compute => (
6890                    "kernel",
6891                    LocationMode::Uniform,
6892                    LocationMode::Uniform,
6893                    false,
6894                ),
6895                crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6896                crate::ShaderStage::RayGeneration
6897                | crate::ShaderStage::AnyHit
6898                | crate::ShaderStage::ClosestHit
6899                | crate::ShaderStage::Miss => unimplemented!(),
6900            };
6901
6902            // Should this entry point be modified to do vertex pulling?
6903            let do_vertex_pulling = can_vertex_pull
6904                && pipeline_options.vertex_pulling_transform
6905                && !pipeline_options.vertex_buffer_mappings.is_empty();
6906
6907            // Is any global variable used by this entry point dynamically sized?
6908            let needs_buffer_sizes = do_vertex_pulling
6909                || module
6910                    .global_variables
6911                    .iter()
6912                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6913                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6914
6915            // skip this entry point if any global bindings are missing,
6916            // or their types are incompatible.
6917            if !options.fake_missing_bindings {
6918                for (var_handle, var) in module.global_variables.iter() {
6919                    if fun_info[var_handle].is_empty() {
6920                        continue;
6921                    }
6922                    match var.space {
6923                        crate::AddressSpace::Uniform
6924                        | crate::AddressSpace::Storage { .. }
6925                        | crate::AddressSpace::Handle => {
6926                            let br = match var.binding {
6927                                Some(ref br) => br,
6928                                None => {
6929                                    let var_name = var.name.clone().unwrap_or_default();
6930                                    ep_error =
6931                                        Some(super::EntryPointError::MissingBinding(var_name));
6932                                    break;
6933                                }
6934                            };
6935                            let target = options.get_resource_binding_target(ep, br);
6936                            let good = match target {
6937                                Some(target) => {
6938                                    // We intentionally don't dereference binding_arrays here,
6939                                    // so that binding arrays fall to the buffer location.
6940
6941                                    match module.types[var.ty].inner {
6942                                        crate::TypeInner::Image {
6943                                            class: crate::ImageClass::External,
6944                                            ..
6945                                        } => target.external_texture.is_some(),
6946                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6947                                        crate::TypeInner::Sampler { .. } => {
6948                                            target.sampler.is_some()
6949                                        }
6950                                        _ => target.buffer.is_some(),
6951                                    }
6952                                }
6953                                None => false,
6954                            };
6955                            if !good {
6956                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6957                                break;
6958                            }
6959                        }
6960                        crate::AddressSpace::Immediate => {
6961                            if let Err(e) = options.resolve_immediates(ep) {
6962                                ep_error = Some(e);
6963                                break;
6964                            }
6965                        }
6966                        crate::AddressSpace::TaskPayload => {
6967                            unimplemented!()
6968                        }
6969                        crate::AddressSpace::Function
6970                        | crate::AddressSpace::Private
6971                        | crate::AddressSpace::WorkGroup => {}
6972                        crate::AddressSpace::RayPayload
6973                        | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6974                    }
6975                }
6976                if needs_buffer_sizes {
6977                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6978                        ep_error = Some(err);
6979                    }
6980                }
6981            }
6982
6983            if let Some(err) = ep_error {
6984                info.entry_point_names.push(Err(err));
6985                continue;
6986            }
6987            let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6988            info.entry_point_names.push(Ok(fun_name.clone()));
6989
6990            writeln!(self.out)?;
6991
6992            // Since `Namer.reset` wasn't expecting struct members to be
6993            // suddenly injected into another namespace like this,
6994            // `self.names` doesn't keep them distinct from other variables.
6995            // Generate fresh names for these arguments, and remember the
6996            // mapping.
6997            let mut flattened_member_names = FastHashMap::default();
6998            // Varyings' members get their own namespace
6999            let mut varyings_namer = proc::Namer::default();
7000
7001            // List all the Naga `EntryPoint`'s `Function`'s arguments,
7002            // flattening structs into their members. In Metal, we will pass
7003            // each of these values to the entry point as a separate argument—
7004            // except for the varyings, handled next.
7005            let mut flattened_arguments = Vec::new();
7006            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7007                match module.types[arg.ty].inner {
7008                    crate::TypeInner::Struct { ref members, .. } => {
7009                        for (member_index, member) in members.iter().enumerate() {
7010                            let member_index = member_index as u32;
7011                            flattened_arguments.push((
7012                                NameKey::StructMember(arg.ty, member_index),
7013                                member.ty,
7014                                member.binding.as_ref(),
7015                            ));
7016                            let name_key = NameKey::StructMember(arg.ty, member_index);
7017                            let name = match member.binding {
7018                                Some(crate::Binding::Location { .. }) => {
7019                                    if do_vertex_pulling {
7020                                        self.namer.call(&self.names[&name_key])
7021                                    } else {
7022                                        varyings_namer.call(&self.names[&name_key])
7023                                    }
7024                                }
7025                                _ => self.namer.call(&self.names[&name_key]),
7026                            };
7027                            flattened_member_names.insert(name_key, name);
7028                        }
7029                    }
7030                    _ => flattened_arguments.push((
7031                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7032                        arg.ty,
7033                        arg.binding.as_ref(),
7034                    )),
7035                }
7036            }
7037
7038            // Identify the varyings among the argument values, and maybe emit
7039            // a struct type named `<fun>Input` to hold them. If we are doing
7040            // vertex pulling, we instead update our attribute mapping to
7041            // note the types, names, and zero values of the attributes.
7042            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7043            let varyings_member_name = self.namer.call("varyings");
7044            let mut has_varyings = false;
7045            if !flattened_arguments.is_empty() {
7046                if !do_vertex_pulling {
7047                    writeln!(self.out, "struct {stage_in_name} {{")?;
7048                }
7049                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7050                    let Some(binding) = binding else {
7051                        continue;
7052                    };
7053                    let name = match *name_key {
7054                        NameKey::StructMember(..) => &flattened_member_names[name_key],
7055                        _ => &self.names[name_key],
7056                    };
7057                    let ty_name = TypeContext {
7058                        handle: ty,
7059                        gctx: module.to_ctx(),
7060                        names: &self.names,
7061                        access: crate::StorageAccess::empty(),
7062                        first_time: false,
7063                    };
7064                    let resolved = options.resolve_local_binding(binding, in_mode)?;
7065                    let location = match *binding {
7066                        crate::Binding::Location { location, .. } => Some(location),
7067                        crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7068                        crate::Binding::BuiltIn(_) => continue,
7069                    };
7070                    if do_vertex_pulling {
7071                        let Some(location) = location else {
7072                            continue;
7073                        };
7074                        // Update our attribute mapping.
7075                        am_resolved.insert(
7076                            location,
7077                            AttributeMappingResolved {
7078                                ty_name: ty_name.to_string(),
7079                                dimension: ty_name.vertex_input_dimension(),
7080                                ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
7081                                name: name.to_string(),
7082                            },
7083                        );
7084                    } else {
7085                        has_varyings = true;
7086                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7087                        resolved.try_fmt(&mut self.out)?;
7088                        writeln!(self.out, ";")?;
7089                    }
7090                }
7091                if !do_vertex_pulling {
7092                    writeln!(self.out, "}};")?;
7093                }
7094            }
7095
7096            // Define a struct type to hold the return value, if any, named
7097            // `<fun>Output`.
7098            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7099            let result_member_name = self.namer.call("member");
7100            let result_type_name = match fun.result {
7101                Some(ref result) => {
7102                    let mut result_members = Vec::new();
7103                    if let crate::TypeInner::Struct { ref members, .. } =
7104                        module.types[result.ty].inner
7105                    {
7106                        for (member_index, member) in members.iter().enumerate() {
7107                            result_members.push((
7108                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7109                                member.ty,
7110                                member.binding.as_ref(),
7111                            ));
7112                        }
7113                    } else {
7114                        result_members.push((
7115                            &result_member_name,
7116                            result.ty,
7117                            result.binding.as_ref(),
7118                        ));
7119                    }
7120
7121                    writeln!(self.out, "struct {stage_out_name} {{")?;
7122                    let mut has_point_size = false;
7123                    for (name, ty, binding) in result_members {
7124                        let ty_name = TypeContext {
7125                            handle: ty,
7126                            gctx: module.to_ctx(),
7127                            names: &self.names,
7128                            access: crate::StorageAccess::empty(),
7129                            first_time: true,
7130                        };
7131                        let binding = binding.ok_or_else(|| {
7132                            Error::GenericValidation("Expected binding, got None".into())
7133                        })?;
7134
7135                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7136                            has_point_size = true;
7137                            if !pipeline_options.allow_and_force_point_size {
7138                                continue;
7139                            }
7140                        }
7141
7142                        let array_len = match module.types[ty].inner {
7143                            crate::TypeInner::Array {
7144                                size: crate::ArraySize::Constant(size),
7145                                ..
7146                            } => Some(size),
7147                            _ => None,
7148                        };
7149                        let resolved = options.resolve_local_binding(binding, out_mode)?;
7150                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7151                        if let Some(array_len) = array_len {
7152                            write!(self.out, " [{array_len}]")?;
7153                        }
7154                        resolved.try_fmt(&mut self.out)?;
7155                        writeln!(self.out, ";")?;
7156                    }
7157
7158                    if pipeline_options.allow_and_force_point_size
7159                        && ep.stage == crate::ShaderStage::Vertex
7160                        && !has_point_size
7161                    {
7162                        // inject the point size output last
7163                        writeln!(
7164                            self.out,
7165                            "{}float _point_size [[point_size]];",
7166                            back::INDENT
7167                        )?;
7168                    }
7169                    writeln!(self.out, "}};")?;
7170                    &stage_out_name
7171                }
7172                None => "void",
7173            };
7174
7175            // If we're doing a vertex pulling transform, define the buffer
7176            // structure types.
7177            if do_vertex_pulling {
7178                for vbm in &vbm_resolved {
7179                    let buffer_stride = vbm.stride;
7180                    let buffer_ty = &vbm.ty_name;
7181
7182                    // Define a structure of bytes of the appropriate size.
7183                    // When we access the attributes, we'll be unpacking these
7184                    // bytes at some offset.
7185                    writeln!(
7186                        self.out,
7187                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7188                    )?;
7189                }
7190            }
7191
7192            // Write the entry point function's name, and begin its argument list.
7193            writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
7194
7195            let mut is_first_argument = true;
7196            let mut separator = || {
7197                if is_first_argument {
7198                    is_first_argument = false;
7199                    ' '
7200                } else {
7201                    ','
7202                }
7203            };
7204
7205            // If we have produced a struct holding the `EntryPoint`'s
7206            // `Function`'s arguments' varyings, pass that struct first.
7207            if has_varyings {
7208                writeln!(
7209                    self.out,
7210                    "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
7211                    separator()
7212                )?;
7213            }
7214
7215            let mut local_invocation_id = None;
7216
7217            // Then pass the remaining arguments not included in the varyings
7218            // struct.
7219            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7220                let binding = match binding {
7221                    Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7222                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7223                    _ => continue,
7224                };
7225                let name = match *name_key {
7226                    NameKey::StructMember(..) => &flattened_member_names[name_key],
7227                    _ => &self.names[name_key],
7228                };
7229
7230                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
7231                    local_invocation_id = Some(name_key);
7232                }
7233
7234                let ty_name = TypeContext {
7235                    handle: ty,
7236                    gctx: module.to_ctx(),
7237                    names: &self.names,
7238                    access: crate::StorageAccess::empty(),
7239                    first_time: false,
7240                };
7241
7242                match *binding {
7243                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7244                        v_existing_id = Some(name.clone());
7245                    }
7246                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7247                        i_existing_id = Some(name.clone());
7248                    }
7249                    _ => {}
7250                };
7251
7252                let resolved = options.resolve_local_binding(binding, in_mode)?;
7253                write!(self.out, "{} {ty_name} {name}", separator())?;
7254                resolved.try_fmt(&mut self.out)?;
7255                writeln!(self.out)?;
7256            }
7257
7258            let need_workgroup_variables_initialization =
7259                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7260
7261            if need_workgroup_variables_initialization && local_invocation_id.is_none() {
7262                writeln!(
7263                    self.out,
7264                    "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
7265                )?;
7266            }
7267
7268            // Those global variables used by this entry point and its callees
7269            // get passed as arguments. `Private` globals are an exception, they
7270            // don't outlive this invocation, so we declare them below as locals
7271            // within the entry point.
7272            for (handle, var) in module.global_variables.iter() {
7273                let usage = fun_info[handle];
7274                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7275                    continue;
7276                }
7277
7278                if options.lang_version < (1, 2) {
7279                    match var.space {
7280                        // This restriction is not documented in the MSL spec
7281                        // but validation will fail if it is not upheld.
7282                        //
7283                        // We infer the required version from the "Function
7284                        // Buffer Read-Writes" section of [what's new], where
7285                        // the feature sets listed correspond with the ones
7286                        // supporting MSL 1.2.
7287                        //
7288                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7289                        crate::AddressSpace::Storage { access }
7290                            if access.contains(crate::StorageAccess::STORE)
7291                                && ep.stage == crate::ShaderStage::Fragment =>
7292                        {
7293                            return Err(Error::UnsupportedWriteableStorageBuffer)
7294                        }
7295                        crate::AddressSpace::Handle => {
7296                            match module.types[var.ty].inner {
7297                                crate::TypeInner::Image {
7298                                    class: crate::ImageClass::Storage { access, .. },
7299                                    ..
7300                                } => {
7301                                    // This restriction is not documented in the MSL spec
7302                                    // but validation will fail if it is not upheld.
7303                                    //
7304                                    // We infer the required version from the "Function
7305                                    // Texture Read-Writes" section of [what's new], where
7306                                    // the feature sets listed correspond with the ones
7307                                    // supporting MSL 1.2.
7308                                    //
7309                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7310                                    if access.contains(crate::StorageAccess::STORE)
7311                                        && (ep.stage == crate::ShaderStage::Vertex
7312                                            || ep.stage == crate::ShaderStage::Fragment)
7313                                    {
7314                                        return Err(Error::UnsupportedWriteableStorageTexture(
7315                                            ep.stage,
7316                                        ));
7317                                    }
7318
7319                                    if access.contains(
7320                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7321                                    ) {
7322                                        return Err(Error::UnsupportedRWStorageTexture);
7323                                    }
7324                                }
7325                                _ => {}
7326                            }
7327                        }
7328                        _ => {}
7329                    }
7330                }
7331
7332                // Check min MSL version for binding arrays
7333                match var.space {
7334                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7335                        crate::TypeInner::BindingArray { base, .. } => {
7336                            match module.types[base].inner {
7337                                crate::TypeInner::Sampler { .. } => {
7338                                    if options.lang_version < (2, 0) {
7339                                        return Err(Error::UnsupportedArrayOf(
7340                                            "samplers".to_string(),
7341                                        ));
7342                                    }
7343                                }
7344                                crate::TypeInner::Image { class, .. } => match class {
7345                                    crate::ImageClass::Sampled { .. }
7346                                    | crate::ImageClass::Depth { .. }
7347                                    | crate::ImageClass::Storage {
7348                                        access: crate::StorageAccess::LOAD,
7349                                        ..
7350                                    } => {
7351                                        // Array of textures since:
7352                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7353                                        // - macOS: Metal 2
7354
7355                                        if options.lang_version < (2, 0) {
7356                                            return Err(Error::UnsupportedArrayOf(
7357                                                "textures".to_string(),
7358                                            ));
7359                                        }
7360                                    }
7361                                    crate::ImageClass::Storage {
7362                                        access: crate::StorageAccess::STORE,
7363                                        ..
7364                                    } => {
7365                                        // Array of write-only textures since:
7366                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7367                                        // - macOS: Metal 2
7368
7369                                        if options.lang_version < (2, 0) {
7370                                            return Err(Error::UnsupportedArrayOf(
7371                                                "write-only textures".to_string(),
7372                                            ));
7373                                        }
7374                                    }
7375                                    crate::ImageClass::Storage { .. } => {
7376                                        if options.lang_version < (3, 0) {
7377                                            return Err(Error::UnsupportedArrayOf(
7378                                                "read-write textures".to_string(),
7379                                            ));
7380                                        }
7381                                    }
7382                                    crate::ImageClass::External => {
7383                                        return Err(Error::UnsupportedArrayOf(
7384                                            "external textures".to_string(),
7385                                        ));
7386                                    }
7387                                },
7388                                _ => {
7389                                    return Err(Error::UnsupportedArrayOfType(base));
7390                                }
7391                            }
7392                        }
7393                        _ => {}
7394                    },
7395                    _ => {}
7396                }
7397
7398                // the resolves have already been checked for `!fake_missing_bindings` case
7399                let resolved = match var.space {
7400                    crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7401                    crate::AddressSpace::WorkGroup => None,
7402                    _ => options
7403                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7404                        .ok(),
7405                };
7406                if let Some(ref resolved) = resolved {
7407                    // Inline samplers are defined in the EP body
7408                    if resolved.as_inline_sampler(options).is_some() {
7409                        continue;
7410                    }
7411                }
7412
7413                match module.types[var.ty].inner {
7414                    crate::TypeInner::Image {
7415                        class: crate::ImageClass::External,
7416                        ..
7417                    } => {
7418                        // External texture global variables get lowered to 3 textures
7419                        // and a constant buffer. We must emit a separate argument for
7420                        // each of these.
7421                        let target = match resolved {
7422                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7423                                target.external_texture
7424                            }
7425                            _ => None,
7426                        };
7427
7428                        for i in 0..3 {
7429                            write!(self.out, "{} ", separator())?;
7430
7431                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7432                                handle,
7433                                ExternalTextureNameKey::Plane(i),
7434                            )];
7435                            write!(
7436                              self.out,
7437                              "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7438                            )?;
7439                            if let Some(ref target) = target {
7440                                write!(self.out, " [[texture({})]]", target.planes[i])?;
7441                            }
7442                            writeln!(self.out)?;
7443                        }
7444                        let params_ty_name = &self.names
7445                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7446                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7447                            handle,
7448                            ExternalTextureNameKey::Params,
7449                        )];
7450                        write!(self.out, "{} ", separator())?;
7451                        write!(self.out, "constant {params_ty_name}& {params_name}")?;
7452                        if let Some(ref target) = target {
7453                            write!(self.out, " [[buffer({})]]", target.params)?;
7454                        }
7455                    }
7456                    _ => {
7457                        let tyvar = TypedGlobalVariable {
7458                            module,
7459                            names: &self.names,
7460                            handle,
7461                            usage,
7462                            reference: true,
7463                        };
7464                        write!(self.out, "{} ", separator())?;
7465                        tyvar.try_fmt(&mut self.out)?;
7466                        if let Some(resolved) = resolved {
7467                            resolved.try_fmt(&mut self.out)?;
7468                        }
7469                        if let Some(value) = var.init {
7470                            write!(self.out, " = ")?;
7471                            self.put_const_expression(
7472                                value,
7473                                module,
7474                                mod_info,
7475                                &module.global_expressions,
7476                            )?;
7477                        }
7478                    }
7479                }
7480                writeln!(self.out)?;
7481            }
7482
7483            if do_vertex_pulling {
7484                if needs_vertex_id && v_existing_id.is_none() {
7485                    // Write the [[vertex_id]] argument.
7486                    writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7487                }
7488
7489                if needs_instance_id && i_existing_id.is_none() {
7490                    writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7491                }
7492
7493                // Iterate vbm_resolved, output one argument for every vertex buffer,
7494                // using the names we generated earlier.
7495                for vbm in &vbm_resolved {
7496                    let id = &vbm.id;
7497                    let ty_name = &vbm.ty_name;
7498                    let param_name = &vbm.param_name;
7499                    writeln!(
7500                        self.out,
7501                        "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7502                        separator()
7503                    )?;
7504                }
7505            }
7506
7507            // If this entry uses any variable-length arrays, their sizes are
7508            // passed as a final struct-typed argument.
7509            if needs_buffer_sizes {
7510                // this is checked earlier
7511                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7512                write!(
7513                    self.out,
7514                    "{} constant _mslBufferSizes& _buffer_sizes",
7515                    separator()
7516                )?;
7517                resolved.try_fmt(&mut self.out)?;
7518                writeln!(self.out)?;
7519            }
7520
7521            // end of the entry point argument list
7522            writeln!(self.out, ") {{")?;
7523
7524            // Starting the function body.
7525            if do_vertex_pulling {
7526                // Provide zero values for all the attributes, which we will overwrite with
7527                // real data from the vertex attribute buffers, if the indices are in-bounds.
7528                for vbm in &vbm_resolved {
7529                    for attribute in vbm.attributes {
7530                        let location = attribute.shader_location;
7531                        let am_option = am_resolved.get(&location);
7532                        if am_option.is_none() {
7533                            // This bound attribute isn't used in this entry point, so
7534                            // don't bother zero-initializing it.
7535                            continue;
7536                        }
7537                        let am = am_option.unwrap();
7538                        let attribute_ty_name = &am.ty_name;
7539                        let attribute_name = &am.name;
7540
7541                        writeln!(
7542                            self.out,
7543                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7544                            back::Level(1)
7545                        )?;
7546                    }
7547
7548                    // Output a bounds check block that will set real values for the
7549                    // attributes, if the bounds are satisfied.
7550                    write!(self.out, "{}if (", back::Level(1))?;
7551
7552                    let idx = &vbm.id;
7553                    let stride = &vbm.stride;
7554                    let index_name = match vbm.step_mode {
7555                        back::msl::VertexBufferStepMode::Constant => "0",
7556                        back::msl::VertexBufferStepMode::ByVertex => {
7557                            if let Some(ref name) = v_existing_id {
7558                                name
7559                            } else {
7560                                &v_id
7561                            }
7562                        }
7563                        back::msl::VertexBufferStepMode::ByInstance => {
7564                            if let Some(ref name) = i_existing_id {
7565                                name
7566                            } else {
7567                                &i_id
7568                            }
7569                        }
7570                    };
7571                    write!(
7572                        self.out,
7573                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7574                    )?;
7575
7576                    writeln!(self.out, ") {{")?;
7577
7578                    // Pull the bytes out of the vertex buffer.
7579                    let ty_name = &vbm.ty_name;
7580                    let elem_name = &vbm.elem_name;
7581                    let param_name = &vbm.param_name;
7582
7583                    writeln!(
7584                        self.out,
7585                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7586                        back::Level(2),
7587                    )?;
7588
7589                    // Now set real values for each of the attributes, by unpacking the data
7590                    // from the buffer elements.
7591                    for attribute in vbm.attributes {
7592                        let location = attribute.shader_location;
7593                        let Some(am) = am_resolved.get(&location) else {
7594                            // This bound attribute isn't used in this entry point, so
7595                            // don't bother extracting the data. Too bad we emitted the
7596                            // unpacking function earlier -- it might not get used.
7597                            continue;
7598                        };
7599                        let attribute_name = &am.name;
7600                        let attribute_ty_name = &am.ty_name;
7601
7602                        let offset = attribute.offset;
7603                        let func = unpacking_functions
7604                            .get(&attribute.format)
7605                            .expect("Should have generated this unpacking function earlier.");
7606                        let func_name = &func.name;
7607
7608                        // Check dimensionality of the attribute compared to the unpacking
7609                        // function. If attribute dimension > unpack dimension, we have to
7610                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7611                        // scalar type. Otherwise, if attribute dimension is < unpack
7612                        // dimension, then we need to explicitly truncate the result.
7613
7614                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7615
7616                        if needs_padding_or_truncation != Ordering::Equal {
7617                            // Emit a comment flagging that a conversion is happening,
7618                            // since the actual logic can be at the end of a long line.
7619                            writeln!(
7620                                self.out,
7621                                "{}// {attribute_ty_name} <- {:?}",
7622                                back::Level(2),
7623                                attribute.format
7624                            )?;
7625                        }
7626
7627                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7628
7629                        if needs_padding_or_truncation == Ordering::Greater {
7630                            // Needs padding: emit constructor call for wider type
7631                            write!(self.out, "{attribute_ty_name}(")?;
7632                        }
7633
7634                        // Emit call to unpacking function
7635                        write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7636                        for i in (offset + 1)..(offset + func.byte_count) {
7637                            write!(self.out, ", {elem_name}.data[{i}]")?;
7638                        }
7639                        write!(self.out, ")")?;
7640
7641                        match needs_padding_or_truncation {
7642                            Ordering::Greater => {
7643                                // Padding
7644                                let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7645                                let one_value = if am.ty_is_int { "1" } else { "1.0" };
7646                                for i in func.dimension..am.dimension {
7647                                    write!(
7648                                        self.out,
7649                                        ", {}",
7650                                        if i == 3 { one_value } else { zero_value }
7651                                    )?;
7652                                }
7653                                write!(self.out, ")")?;
7654                            }
7655                            Ordering::Less => {
7656                                // Truncate to the first `am.dimension` components
7657                                write!(
7658                                    self.out,
7659                                    ".{}",
7660                                    &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7661                                )?;
7662                            }
7663                            Ordering::Equal => {}
7664                        }
7665
7666                        writeln!(self.out, ";")?;
7667                    }
7668
7669                    // End the bounds check / attribute setting block.
7670                    writeln!(self.out, "{}}}", back::Level(1))?;
7671                }
7672            }
7673
7674            if need_workgroup_variables_initialization {
7675                self.write_workgroup_variables_initialization(
7676                    module,
7677                    mod_info,
7678                    fun_info,
7679                    local_invocation_id,
7680                )?;
7681            }
7682
7683            // Metal doesn't support private mutable variables outside of functions,
7684            // so we put them here, just like the locals.
7685            for (handle, var) in module.global_variables.iter() {
7686                let usage = fun_info[handle];
7687                if usage.is_empty() {
7688                    continue;
7689                }
7690                if var.space == crate::AddressSpace::Private {
7691                    let tyvar = TypedGlobalVariable {
7692                        module,
7693                        names: &self.names,
7694                        handle,
7695                        usage,
7696
7697                        reference: false,
7698                    };
7699                    write!(self.out, "{}", back::INDENT)?;
7700                    tyvar.try_fmt(&mut self.out)?;
7701                    match var.init {
7702                        Some(value) => {
7703                            write!(self.out, " = ")?;
7704                            self.put_const_expression(
7705                                value,
7706                                module,
7707                                mod_info,
7708                                &module.global_expressions,
7709                            )?;
7710                            writeln!(self.out, ";")?;
7711                        }
7712                        None => {
7713                            writeln!(self.out, " = {{}};")?;
7714                        }
7715                    };
7716                } else if let Some(ref binding) = var.binding {
7717                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7718                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7719                        // write an inline sampler
7720                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7721                        writeln!(
7722                            self.out,
7723                            "{}constexpr {}::sampler {}(",
7724                            back::INDENT,
7725                            NAMESPACE,
7726                            name
7727                        )?;
7728                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7729                        writeln!(self.out, "{});", back::INDENT)?;
7730                    } else if let crate::TypeInner::Image {
7731                        class: crate::ImageClass::External,
7732                        ..
7733                    } = module.types[var.ty].inner
7734                    {
7735                        // Wrap the individual arguments for each external texture global
7736                        // in a struct which can be easily passed around.
7737                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7738                        let l1 = back::Level(1);
7739                        let l2 = l1.next();
7740                        writeln!(
7741                            self.out,
7742                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7743                        )?;
7744                        for i in 0..3 {
7745                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7746                                handle,
7747                                ExternalTextureNameKey::Plane(i),
7748                            )];
7749                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7750                        }
7751                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7752                            handle,
7753                            ExternalTextureNameKey::Params,
7754                        )];
7755                        writeln!(self.out, "{l2}.params = {params_name},")?;
7756                        writeln!(self.out, "{l1}}};")?;
7757                    }
7758                }
7759            }
7760
7761            // Now take the arguments that we gathered into structs, and the
7762            // structs that we flattened into arguments, and emit local
7763            // variables with initializers that put everything back the way the
7764            // body code expects.
7765            //
7766            // If we had to generate fresh names for struct members passed as
7767            // arguments, be sure to use those names when rebuilding the struct.
7768            //
7769            // "Each day, I change some zeros to ones, and some ones to zeros.
7770            // The rest, I leave alone."
7771            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7772                let arg_name =
7773                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7774                match module.types[arg.ty].inner {
7775                    crate::TypeInner::Struct { ref members, .. } => {
7776                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7777                        write!(
7778                            self.out,
7779                            "{}const {} {} = {{ ",
7780                            back::INDENT,
7781                            struct_name,
7782                            arg_name
7783                        )?;
7784                        for (member_index, member) in members.iter().enumerate() {
7785                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7786                            let name = &flattened_member_names[&key];
7787                            if member_index != 0 {
7788                                write!(self.out, ", ")?;
7789                            }
7790                            // insert padding initialization, if needed
7791                            if self
7792                                .struct_member_pads
7793                                .contains(&(arg.ty, member_index as u32))
7794                            {
7795                                write!(self.out, "{{}}, ")?;
7796                            }
7797                            if let Some(crate::Binding::Location { .. }) = member.binding {
7798                                if has_varyings {
7799                                    write!(self.out, "{varyings_member_name}.")?;
7800                                }
7801                            }
7802                            write!(self.out, "{name}")?;
7803                        }
7804                        writeln!(self.out, " }};")?;
7805                    }
7806                    _ => match arg.binding {
7807                        Some(crate::Binding::Location { .. })
7808                        | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
7809                            if has_varyings {
7810                                writeln!(
7811                                    self.out,
7812                                    "{}const auto {} = {}.{};",
7813                                    back::INDENT,
7814                                    arg_name,
7815                                    varyings_member_name,
7816                                    arg_name
7817                                )?;
7818                            }
7819                        }
7820                        _ => {}
7821                    },
7822                }
7823            }
7824
7825            let guarded_indices =
7826                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7827
7828            let context = StatementContext {
7829                expression: ExpressionContext {
7830                    function: fun,
7831                    origin: FunctionOrigin::EntryPoint(ep_index as _),
7832                    info: fun_info,
7833                    lang_version: options.lang_version,
7834                    policies: options.bounds_check_policies,
7835                    guarded_indices,
7836                    module,
7837                    mod_info,
7838                    pipeline_options,
7839                    force_loop_bounding: options.force_loop_bounding,
7840                },
7841                result_struct: Some(&stage_out_name),
7842            };
7843
7844            // Finally, declare all the local variables that we need
7845            //TODO: we can postpone this till the relevant expressions are emitted
7846            self.put_locals(&context.expression)?;
7847            self.update_expressions_to_bake(fun, fun_info, &context.expression);
7848            self.put_block(back::Level(1), &fun.body, &context)?;
7849            writeln!(self.out, "}}")?;
7850            if ep_index + 1 != module.entry_points.len() {
7851                writeln!(self.out)?;
7852            }
7853            self.named_expressions.clear();
7854        }
7855
7856        Ok(info)
7857    }
7858
7859    fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7860        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
7861        // so we try to avoid it here.
7862        if flags.is_empty() {
7863            writeln!(
7864                self.out,
7865                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7866            )?;
7867        }
7868        if flags.contains(crate::Barrier::STORAGE) {
7869            writeln!(
7870                self.out,
7871                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7872            )?;
7873        }
7874        if flags.contains(crate::Barrier::WORK_GROUP) {
7875            writeln!(
7876                self.out,
7877                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7878            )?;
7879        }
7880        if flags.contains(crate::Barrier::SUB_GROUP) {
7881            writeln!(
7882                self.out,
7883                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7884            )?;
7885        }
7886        if flags.contains(crate::Barrier::TEXTURE) {
7887            writeln!(
7888                self.out,
7889                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7890            )?;
7891        }
7892        Ok(())
7893    }
7894}
7895
7896/// Initializing workgroup variables is more tricky for Metal because we have to deal
7897/// with atomics at the type-level (which don't have a copy constructor).
7898mod workgroup_mem_init {
7899    use crate::EntryPoint;
7900
7901    use super::*;
7902
7903    enum Access {
7904        GlobalVariable(Handle<crate::GlobalVariable>),
7905        StructMember(Handle<crate::Type>, u32),
7906        Array(usize),
7907    }
7908
7909    impl Access {
7910        fn write<W: Write>(
7911            &self,
7912            writer: &mut W,
7913            names: &FastHashMap<NameKey, String>,
7914        ) -> Result<(), core::fmt::Error> {
7915            match *self {
7916                Access::GlobalVariable(handle) => {
7917                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7918                }
7919                Access::StructMember(handle, index) => {
7920                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7921                }
7922                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7923            }
7924        }
7925    }
7926
7927    struct AccessStack {
7928        stack: Vec<Access>,
7929        array_depth: usize,
7930    }
7931
7932    impl AccessStack {
7933        const fn new() -> Self {
7934            Self {
7935                stack: Vec::new(),
7936                array_depth: 0,
7937            }
7938        }
7939
7940        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7941            let array_depth = self.array_depth;
7942            self.stack.push(Access::Array(array_depth));
7943            self.array_depth += 1;
7944            let res = cb(self, array_depth);
7945            self.stack.pop();
7946            self.array_depth -= 1;
7947            res
7948        }
7949
7950        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7951            self.stack.push(new);
7952            let res = cb(self);
7953            self.stack.pop();
7954            res
7955        }
7956
7957        fn write<W: Write>(
7958            &self,
7959            writer: &mut W,
7960            names: &FastHashMap<NameKey, String>,
7961        ) -> Result<(), core::fmt::Error> {
7962            for next in self.stack.iter() {
7963                next.write(writer, names)?;
7964            }
7965            Ok(())
7966        }
7967    }
7968
7969    impl<W: Write> Writer<W> {
7970        pub(super) fn need_workgroup_variables_initialization(
7971            &mut self,
7972            options: &Options,
7973            ep: &EntryPoint,
7974            module: &crate::Module,
7975            fun_info: &valid::FunctionInfo,
7976        ) -> bool {
7977            options.zero_initialize_workgroup_memory
7978                && ep.stage.compute_like()
7979                && module.global_variables.iter().any(|(handle, var)| {
7980                    !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7981                })
7982        }
7983
7984        pub(super) fn write_workgroup_variables_initialization(
7985            &mut self,
7986            module: &crate::Module,
7987            module_info: &valid::ModuleInfo,
7988            fun_info: &valid::FunctionInfo,
7989            local_invocation_id: Option<&NameKey>,
7990        ) -> BackendResult {
7991            let level = back::Level(1);
7992
7993            writeln!(
7994                self.out,
7995                "{}if ({}::all({} == {}::uint3(0u))) {{",
7996                level,
7997                NAMESPACE,
7998                local_invocation_id
7999                    .map(|name_key| self.names[name_key].as_str())
8000                    .unwrap_or("__local_invocation_id"),
8001                NAMESPACE,
8002            )?;
8003
8004            let mut access_stack = AccessStack::new();
8005
8006            let vars = module.global_variables.iter().filter(|&(handle, var)| {
8007                !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
8008            });
8009
8010            for (handle, var) in vars {
8011                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8012                    self.write_workgroup_variable_initialization(
8013                        module,
8014                        module_info,
8015                        var.ty,
8016                        access_stack,
8017                        level.next(),
8018                    )
8019                })?;
8020            }
8021
8022            writeln!(self.out, "{level}}}")?;
8023            self.write_barrier(crate::Barrier::WORK_GROUP, level)
8024        }
8025
8026        fn write_workgroup_variable_initialization(
8027            &mut self,
8028            module: &crate::Module,
8029            module_info: &valid::ModuleInfo,
8030            ty: Handle<crate::Type>,
8031            access_stack: &mut AccessStack,
8032            level: back::Level,
8033        ) -> BackendResult {
8034            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8035                write!(self.out, "{level}")?;
8036                access_stack.write(&mut self.out, &self.names)?;
8037                writeln!(self.out, " = {{}};")?;
8038            } else {
8039                match module.types[ty].inner {
8040                    crate::TypeInner::Atomic { .. } => {
8041                        write!(
8042                            self.out,
8043                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8044                        )?;
8045                        access_stack.write(&mut self.out, &self.names)?;
8046                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8047                    }
8048                    crate::TypeInner::Array { base, size, .. } => {
8049                        let count = match size.resolve(module.to_ctx())? {
8050                            proc::IndexableLength::Known(count) => count,
8051                            proc::IndexableLength::Dynamic => unreachable!(),
8052                        };
8053
8054                        access_stack.enter_array(|access_stack, array_depth| {
8055                            writeln!(
8056                                self.out,
8057                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8058                            )?;
8059                            self.write_workgroup_variable_initialization(
8060                                module,
8061                                module_info,
8062                                base,
8063                                access_stack,
8064                                level.next(),
8065                            )?;
8066                            writeln!(self.out, "{level}}}")?;
8067                            BackendResult::Ok(())
8068                        })?;
8069                    }
8070                    crate::TypeInner::Struct { ref members, .. } => {
8071                        for (index, member) in members.iter().enumerate() {
8072                            access_stack.enter(
8073                                Access::StructMember(ty, index as u32),
8074                                |access_stack| {
8075                                    self.write_workgroup_variable_initialization(
8076                                        module,
8077                                        module_info,
8078                                        member.ty,
8079                                        access_stack,
8080                                        level,
8081                                    )
8082                                },
8083                            )?;
8084                        }
8085                    }
8086                    _ => unreachable!(),
8087                }
8088            }
8089
8090            Ok(())
8091        }
8092    }
8093}
8094
8095impl crate::AtomicFunction {
8096    const fn to_msl(self) -> &'static str {
8097        match self {
8098            Self::Add => "fetch_add",
8099            Self::Subtract => "fetch_sub",
8100            Self::And => "fetch_and",
8101            Self::InclusiveOr => "fetch_or",
8102            Self::ExclusiveOr => "fetch_xor",
8103            Self::Min => "fetch_min",
8104            Self::Max => "fetch_max",
8105            Self::Exchange { compare: None } => "exchange",
8106            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8107        }
8108    }
8109
8110    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8111        Ok(match self {
8112            Self::Min => "min",
8113            Self::Max => "max",
8114            _ => Err(Error::FeatureNotImplemented(
8115                "64-bit atomic operation other than min/max".to_string(),
8116            ))?,
8117        })
8118    }
8119}