// naga/back/msl/writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17    ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18    TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21    arena::{Handle, HandleSet},
22    back::{
23        self, get_entry_points,
24        msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25        Baked,
26    },
27    common,
28    proc::{
29        self, concrete_int_scalars,
30        index::{self, BoundsCheck},
31        ExternalTextureNameKey, NameKey, TypeResolution,
32    },
33    valid, FastHashMap, FastHashSet,
34};
35
// This is a hack: we need to pass a pointer to an atomic, but in general the
// backend does not put "&" in front of every pointer. More general handling of
// pointers still needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

// Identifiers of Naga-generated helper functions referenced by the emitted MSL.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
// Prefix only: the full helper name presumably gets a type-specific suffix appended.
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Metal does not accept `metal::texture<..>*` as a buffer argument, but it
/// does accept a pointer to a struct that contains such a texture.
///
/// We therefore wrap every argument buffer in a struct with a single generic
/// `<T>` field, so that `NagaArgumentBufferWrapper<metal::texture<..>>*` is
/// accepted. Semantically this should be identical to the unwrapped form; the
/// wrapper exists only to work around this restriction.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
// Helper function names used for cooperative (simdgroup) matrix operations.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
72
73/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
74///
75/// The `sizes` slice determines whether this function writes a
76/// scalar, vector, or matrix type:
77///
78/// - An empty slice produces a scalar type.
79/// - A one-element slice produces a vector type.
80/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
81fn put_numeric_type(
82    out: &mut impl Write,
83    scalar: crate::Scalar,
84    sizes: &[crate::VectorSize],
85) -> Result<(), FmtError> {
86    match (scalar, sizes) {
87        (scalar, &[]) => {
88            write!(out, "{}", scalar.to_msl_name())
89        }
90        (scalar, &[rows]) => {
91            write!(
92                out,
93                "{}::{}{}",
94                NAMESPACE,
95                scalar.to_msl_name(),
96                common::vector_size_str(rows)
97            )
98        }
99        (scalar, &[rows, columns]) => {
100            write!(
101                out,
102                "{}::{}{}x{}",
103                NAMESPACE,
104                scalar.to_msl_name(),
105                common::vector_size_str(columns),
106                common::vector_size_str(rows)
107            )
108        }
109        (_, _) => Ok(()), // not meaningful
110    }
111}
112
113const fn scalar_is_int(scalar: crate::Scalar) -> bool {
114    use crate::ScalarKind::*;
115    match scalar.kind {
116        Sint | Uint | AbstractInt | Bool => true,
117        Float | AbstractFloat => false,
118    }
119}
120
121/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
122const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
123
124/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
125const REINTERPRET_PREFIX: &str = "reinterpreted_";
126
127/// Wrapper for identifier names for clamped level-of-detail values
128///
129/// Values of this type implement [`core::fmt::Display`], formatting as
130/// the name of the variable used to hold the cached clamped
131/// level-of-detail value for an `ImageLoad` expression.
132struct ClampedLod(Handle<crate::Expression>);
133
134impl Display for ClampedLod {
135    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
136        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
137    }
138}
139
140/// Wrapper for generating `struct _mslBufferSizes` member names for
141/// runtime-sized array lengths.
142///
143/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
144/// as an argument to the entry point. This argument's type in the MSL is
145/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
146/// each global variable containing a runtime-sized array.
147///
148/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
149/// runtime-sized array, then the value `ArraySize(global)` implements
150/// [`core::fmt::Display`], formatting as the name of the struct member carrying
151/// the number of elements in that runtime-sized array.
152///
153/// [`GlobalVariable`]: crate::GlobalVariable
154struct ArraySizeMember(Handle<crate::GlobalVariable>);
155
156impl Display for ArraySizeMember {
157    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
158        self.0.write_prefixed(f, "size")
159    }
160}
161
162/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
163///
164/// Implements [`core::fmt::Display`], formatting as a name derived from
165/// `target_type` and the variable name of `orig`.
166#[derive(Clone, Copy)]
167struct Reinterpreted<'a> {
168    target_type: &'a str,
169    orig: Handle<crate::Expression>,
170}
171
172impl<'a> Reinterpreted<'a> {
173    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
174        Self { target_type, orig }
175    }
176}
177
178impl Display for Reinterpreted<'_> {
179    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
180        f.write_str(REINTERPRET_PREFIX)?;
181        f.write_str(self.target_type)?;
182        self.orig.write_prefixed(f, "_e")
183    }
184}
185
186pub(super) struct TypeContext<'a> {
187    pub handle: Handle<crate::Type>,
188    pub gctx: proc::GlobalCtx<'a>,
189    pub names: &'a FastHashMap<NameKey, String>,
190    pub access: crate::StorageAccess,
191    pub first_time: bool,
192}
193
194impl TypeContext<'_> {
195    fn scalar(&self) -> Option<crate::Scalar> {
196        let ty = &self.gctx.types[self.handle];
197        ty.inner.scalar()
198    }
199
200    fn vector_size(&self) -> Option<crate::VectorSize> {
201        let ty = &self.gctx.types[self.handle];
202        match ty.inner {
203            crate::TypeInner::Vector { size, .. } => Some(size),
204            _ => None,
205        }
206    }
207
208    fn unwrap_array(self) -> Self {
209        match self.gctx.types[self.handle].inner {
210            crate::TypeInner::Array { base, .. } => Self {
211                handle: base,
212                ..self
213            },
214            _ => self,
215        }
216    }
217}
218
219impl Display for TypeContext<'_> {
220    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
221        let ty = &self.gctx.types[self.handle];
222        if ty.needs_alias() && !self.first_time {
223            let name = &self.names[&NameKey::Type(self.handle)];
224            return write!(out, "{name}");
225        }
226
227        match ty.inner {
228            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
229            crate::TypeInner::Atomic(scalar) => {
230                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
231            }
232            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
233            crate::TypeInner::Matrix {
234                columns,
235                rows,
236                scalar,
237            } => put_numeric_type(out, scalar, &[rows, columns]),
238            // Requires Metal-2.3
239            crate::TypeInner::CooperativeMatrix {
240                columns,
241                rows,
242                scalar,
243                role: _,
244            } => {
245                write!(
246                    out,
247                    "{NAMESPACE}::simdgroup_{}{}x{}",
248                    scalar.to_msl_name(),
249                    columns as u32,
250                    rows as u32,
251                )
252            }
253            crate::TypeInner::Pointer { base, space } => {
254                let sub = Self {
255                    handle: base,
256                    first_time: false,
257                    ..*self
258                };
259                let space_name = match space.to_msl_name() {
260                    Some(name) => name,
261                    None => return Ok(()),
262                };
263                write!(out, "{space_name} {sub}&")
264            }
265            crate::TypeInner::ValuePointer {
266                size,
267                scalar,
268                space,
269            } => {
270                match space.to_msl_name() {
271                    Some(name) => write!(out, "{name} ")?,
272                    None => return Ok(()),
273                };
274                match size {
275                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
276                    None => put_numeric_type(out, scalar, &[])?,
277                };
278
279                write!(out, "&")
280            }
281            crate::TypeInner::Array { base, .. } => {
282                let sub = Self {
283                    handle: base,
284                    first_time: false,
285                    ..*self
286                };
287                // Array lengths go at the end of the type definition,
288                // so just print the element type here.
289                write!(out, "{sub}")
290            }
291            crate::TypeInner::Struct { .. } => unreachable!(),
292            crate::TypeInner::Image {
293                dim,
294                arrayed,
295                class,
296            } => {
297                let dim_str = match dim {
298                    crate::ImageDimension::D1 => "1d",
299                    crate::ImageDimension::D2 => "2d",
300                    crate::ImageDimension::D3 => "3d",
301                    crate::ImageDimension::Cube => "cube",
302                };
303                let (texture_str, msaa_str, scalar, access) = match class {
304                    crate::ImageClass::Sampled { kind, multi } => {
305                        let (msaa_str, access) = if multi {
306                            ("_ms", "read")
307                        } else {
308                            ("", "sample")
309                        };
310                        let scalar = crate::Scalar { kind, width: 4 };
311                        ("texture", msaa_str, scalar, access)
312                    }
313                    crate::ImageClass::Depth { multi } => {
314                        let (msaa_str, access) = if multi {
315                            ("_ms", "read")
316                        } else {
317                            ("", "sample")
318                        };
319                        let scalar = crate::Scalar {
320                            kind: crate::ScalarKind::Float,
321                            width: 4,
322                        };
323                        ("depth", msaa_str, scalar, access)
324                    }
325                    crate::ImageClass::Storage { format, .. } => {
326                        let access = if self
327                            .access
328                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
329                        {
330                            "read_write"
331                        } else if self.access.contains(crate::StorageAccess::STORE) {
332                            "write"
333                        } else if self.access.contains(crate::StorageAccess::LOAD) {
334                            "read"
335                        } else {
336                            log::warn!(
337                                "Storage access for {:?} (name '{}'): {:?}",
338                                self.handle,
339                                ty.name.as_deref().unwrap_or_default(),
340                                self.access
341                            );
342                            unreachable!("module is not valid");
343                        };
344                        ("texture", "", format.into(), access)
345                    }
346                    crate::ImageClass::External => {
347                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
348                    }
349                };
350                let base_name = scalar.to_msl_name();
351                let array_str = if arrayed { "_array" } else { "" };
352                write!(
353                    out,
354                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
355                )
356            }
357            crate::TypeInner::Sampler { comparison: _ } => {
358                write!(out, "{NAMESPACE}::sampler")
359            }
360            crate::TypeInner::AccelerationStructure { vertex_return } => {
361                if vertex_return {
362                    unimplemented!("metal does not support vertex ray hit return")
363                }
364                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
365            }
366            crate::TypeInner::RayQuery { vertex_return } => {
367                if vertex_return {
368                    unimplemented!("metal does not support vertex ray hit return")
369                }
370                write!(out, "{}", super::ray::metal_intersector_ty())
371            }
372            crate::TypeInner::BindingArray { base, .. } => {
373                let base_inner = &self.gctx.types[base].inner;
374                let base_tyname = Self {
375                    handle: base,
376                    first_time: false,
377                    ..*self
378                };
379                match *base_inner {
380                    crate::TypeInner::Struct { .. } => {
381                        // Buffers in a binding array are pointers declared as `device T*`, so members use `->`.
382                        // Textures and samplers stay as plain values inside the wrapper.
383                        write!(
384                            out,
385                            "device {ARGUMENT_BUFFER_WRAPPER_STRUCT}<device {base_tyname}*>*"
386                        )
387                    }
388                    _ => {
389                        write!(
390                            out,
391                            "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
392                        )
393                    }
394                }
395            }
396        }
397    }
398}
399
400pub(super) struct TypedGlobalVariable<'a> {
401    pub module: &'a crate::Module,
402    pub names: &'a FastHashMap<NameKey, String>,
403    pub handle: Handle<crate::GlobalVariable>,
404    pub usage: valid::GlobalUse,
405    pub reference: bool,
406}
407
408struct TypedGlobalVariableParts {
409    ty_name: String,
410    var_name: String,
411}
412
413impl TypedGlobalVariable<'_> {
414    fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
415        let var = &self.module.global_variables[self.handle];
416        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
417
418        let storage_access = match var.space {
419            crate::AddressSpace::Storage { access } => access,
420            _ => match self.module.types[var.ty].inner {
421                crate::TypeInner::Image {
422                    class: crate::ImageClass::Storage { access, .. },
423                    ..
424                } => access,
425                crate::TypeInner::BindingArray { base, .. } => {
426                    match self.module.types[base].inner {
427                        crate::TypeInner::Image {
428                            class: crate::ImageClass::Storage { access, .. },
429                            ..
430                        } => access,
431                        _ => crate::StorageAccess::default(),
432                    }
433                }
434                _ => crate::StorageAccess::default(),
435            },
436        };
437        let ty_name = TypeContext {
438            handle: var.ty,
439            gctx: self.module.to_ctx(),
440            names: self.names,
441            access: storage_access,
442            first_time: false,
443        };
444
445        let (coherent, space, access, reference) = if matches!(
446            self.module.types[var.ty].inner,
447            crate::TypeInner::BindingArray { .. }
448        ) {
449            ("", "", "", "")
450        } else {
451            let access = if var.space.needs_access_qualifier()
452                && !self.usage.intersects(valid::GlobalUse::WRITE)
453            {
454                "const"
455            } else {
456                ""
457            };
458            match (var.space.to_msl_name(), var.space) {
459                (Some(space), crate::AddressSpace::WorkGroup) => {
460                    ("", space, access, if self.reference { "&" } else { "" })
461                }
462                (Some(space), _) if self.reference => {
463                    let coherent = if var
464                        .memory_decorations
465                        .contains(crate::MemoryDecorations::COHERENT)
466                    {
467                        "coherent "
468                    } else {
469                        ""
470                    };
471                    (coherent, space, access, "&")
472                }
473                _ => ("", "", "", ""),
474            }
475        };
476
477        let ty = format!(
478            "{coherent}{space}{}{ty_name}{}{access}{reference}",
479            if space.is_empty() { "" } else { " " },
480            if access.is_empty() { "" } else { " " },
481        );
482
483        Ok(TypedGlobalVariableParts {
484            ty_name: ty,
485            var_name: name.clone(),
486        })
487    }
488    pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
489        let parts = self.to_parts()?;
490
491        Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
492    }
493}
494
495#[derive(Eq, PartialEq, Hash)]
496pub(super) enum WrappedFunction {
497    UnaryOp {
498        op: crate::UnaryOperator,
499        ty: (Option<crate::VectorSize>, crate::Scalar),
500    },
501    BinaryOp {
502        op: crate::BinaryOperator,
503        left_ty: (Option<crate::VectorSize>, crate::Scalar),
504        right_ty: (Option<crate::VectorSize>, crate::Scalar),
505    },
506    Math {
507        fun: crate::MathFunction,
508        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
509    },
510    Cast {
511        src_scalar: crate::Scalar,
512        vector_size: Option<crate::VectorSize>,
513        dst_scalar: crate::Scalar,
514    },
515    ImageLoad {
516        class: crate::ImageClass,
517    },
518    ImageSample {
519        class: crate::ImageClass,
520        clamp_to_edge: bool,
521    },
522    ImageQuerySize {
523        class: crate::ImageClass,
524    },
525    CooperativeLoad {
526        space_name: &'static str,
527        columns: crate::CooperativeSize,
528        rows: crate::CooperativeSize,
529        scalar: crate::Scalar,
530    },
531    CooperativeMultiplyAdd {
532        space_name: &'static str,
533        columns: crate::CooperativeSize,
534        rows: crate::CooperativeSize,
535        intermediate: crate::CooperativeSize,
536        ab_scalar: crate::Scalar,
537        c_scalar: crate::Scalar,
538    },
539    RayQueryGetIntersection {
540        committed: bool,
541    },
542}
543
544#[expect(missing_debug_implementations, reason = "would be way too verbose?")]
545pub struct Writer<W> {
546    pub(super) out: W,
547    pub(super) names: FastHashMap<NameKey, String>,
548    pub(super) named_expressions: crate::NamedExpressions,
549    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
550    need_bake_expressions: back::NeedBakeExpressions,
551    pub(super) namer: proc::Namer,
552    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
553    emit_int_div_checks: bool,
554    /// Set of (struct type, struct field index) denoting which fields require
555    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
556    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
557    needs_object_memory_barriers: bool,
558}
559
560impl crate::Scalar {
561    pub(super) fn to_msl_name(self) -> &'static str {
562        use crate::ScalarKind as Sk;
563        match self {
564            Self {
565                kind: Sk::Float,
566                width: 4,
567            } => "float",
568            Self {
569                kind: Sk::Float,
570                width: 2,
571            } => "half",
572            Self {
573                kind: Sk::Sint,
574                width: 2,
575            } => "short",
576            Self {
577                kind: Sk::Uint,
578                width: 2,
579            } => "ushort",
580            Self {
581                kind: Sk::Sint,
582                width: 4,
583            } => "int",
584            Self {
585                kind: Sk::Uint,
586                width: 4,
587            } => "uint",
588            Self {
589                kind: Sk::Sint,
590                width: 8,
591            } => "long",
592            Self {
593                kind: Sk::Uint,
594                width: 8,
595            } => "ulong",
596            Self {
597                kind: Sk::Bool,
598                width: _,
599            } => "bool",
600            Self {
601                kind: Sk::AbstractInt | Sk::AbstractFloat,
602                width: _,
603            } => unreachable!("Found Abstract scalar kind"),
604            _ => unreachable!("Unsupported scalar kind: {:?}", self),
605        }
606    }
607}
608
609const fn separate(need_separator: bool) -> &'static str {
610    if need_separator {
611        ","
612    } else {
613        ""
614    }
615}
616
617fn should_pack_struct_member(
618    members: &[crate::StructMember],
619    span: u32,
620    index: usize,
621    module: &crate::Module,
622) -> Option<crate::Scalar> {
623    let member = &members[index];
624
625    let ty_inner = &module.types[member.ty].inner;
626    let last_offset = member.offset + ty_inner.size(module.to_ctx());
627    let next_offset = match members.get(index + 1) {
628        Some(next) => next.offset,
629        None => span,
630    };
631    let is_tight = next_offset == last_offset;
632
633    match *ty_inner {
634        crate::TypeInner::Vector {
635            size: crate::VectorSize::Tri,
636            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
637        } if is_tight => Some(scalar),
638        _ => None,
639    }
640}
641
642impl crate::AddressSpace {
643    /// Returns true if global variables in this address space are
644    /// passed in function arguments. These arguments need to be
645    /// passed through any functions called from the entry point.
646    const fn needs_pass_through(&self) -> bool {
647        match *self {
648            Self::Uniform
649            | Self::Storage { .. }
650            | Self::Private
651            | Self::WorkGroup
652            | Self::Immediate
653            | Self::Handle
654            | Self::TaskPayload => true,
655            Self::Function => false,
656            Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
657        }
658    }
659
660    /// Returns true if the address space may need a "const" qualifier.
661    const fn needs_access_qualifier(&self) -> bool {
662        match *self {
663            //Note: we are ignoring the storage access here, and instead
664            // rely on the actual use of a global by functions. This means we
665            // may end up with "const" even if the binding is read-write,
666            // and that should be OK.
667            Self::Storage { .. } => true,
668            Self::TaskPayload => true,
669            Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
670            // These should always be read-write.
671            Self::Private | Self::WorkGroup => false,
672            // These translate to `constant` address space, no need for qualifiers.
673            Self::Uniform | Self::Immediate => false,
674            // Not applicable.
675            Self::Handle | Self::Function => false,
676        }
677    }
678
679    const fn to_msl_name(self) -> Option<&'static str> {
680        match self {
681            Self::Handle => None,
682            Self::Uniform | Self::Immediate => Some("constant"),
683            Self::Storage { .. } => Some("device"),
684            // note for `RayPayload`, this probably needs to be emulated as a
685            // private variable, as metal has essentially an inout input
686            // for where it is passed.
687            Self::Private | Self::Function | Self::RayPayload => Some("thread"),
688            Self::WorkGroup => Some("threadgroup"),
689            Self::TaskPayload => Some("object_data"),
690            Self::IncomingRayPayload => Some("ray_data"),
691        }
692    }
693}
694
695impl crate::Type {
696    // Returns `true` if we need to emit an alias for this type.
697    const fn needs_alias(&self) -> bool {
698        use crate::TypeInner as Ti;
699
700        match self.inner {
701            // value types are concise enough, we only alias them if they are named
702            Ti::Scalar(_)
703            | Ti::Vector { .. }
704            | Ti::Matrix { .. }
705            | Ti::CooperativeMatrix { .. }
706            | Ti::Atomic(_)
707            | Ti::Pointer { .. }
708            | Ti::ValuePointer { .. } => self.name.is_some(),
709            // composite types are better to be aliased, regardless of the name
710            Ti::Struct { .. } | Ti::Array { .. } => true,
711            // handle types may be different, depending on the global var access, so we always inline them
712            Ti::Image { .. }
713            | Ti::Sampler { .. }
714            | Ti::AccelerationStructure { .. }
715            | Ti::RayQuery { .. }
716            | Ti::BindingArray { .. } => false,
717        }
718    }
719}
720
721#[derive(Clone, Copy)]
722pub(super) enum FunctionOrigin {
723    Handle(Handle<crate::Function>),
724    EntryPoint(proc::EntryPointIndex),
725}
726
727pub(super) trait NameKeyExt {
728    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
729        match origin {
730            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
731            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
732        }
733    }
734
735    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
736    /// policy when it needs to produce a pointer-typed result for an OOB access. These
737    /// are unique per accessed type, so the second argument is a type handle. See docs
738    /// for [`crate::back::msl`].
739    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
740        match origin {
741            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
742            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
743        }
744    }
745}
746
747impl NameKeyExt for NameKey {}
748
749/// A level of detail argument.
750///
751/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
752/// save the clamped level of detail in a temporary variable whose name is based
753/// on the handle of the `ImageLoad` expression. But for other policies, we just
754/// use the expression directly.
755///
756/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
757/// [`ImageLoad`]: crate::Expression::ImageLoad
758#[derive(Clone, Copy)]
759enum LevelOfDetail {
760    Direct(Handle<crate::Expression>),
761    Restricted(Handle<crate::Expression>),
762}
763
764/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
765///
766/// When this is used in code paths unconcerned with the `Restrict` bounds check
767/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
768/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
769/// be too awkward. If that changes, we can revisit.
770///
771/// [`ImageLoad`]: crate::Expression::ImageLoad
772/// [`ImageStore`]: crate::Statement::ImageStore
773struct TexelAddress {
774    coordinate: Handle<crate::Expression>,
775    array_index: Option<Handle<crate::Expression>>,
776    sample: Option<Handle<crate::Expression>>,
777    level: Option<LevelOfDetail>,
778}
779
780pub(super) struct ExpressionContext<'a> {
781    pub(super) function: &'a crate::Function,
782    pub(super) origin: FunctionOrigin,
783    pub(super) info: &'a valid::FunctionInfo,
784    pub(super) module: &'a crate::Module,
785    pub(super) mod_info: &'a valid::ModuleInfo,
786    pub(super) pipeline_options: &'a PipelineOptions,
787    pub(super) lang_version: (u8, u8),
788    pub(super) policies: index::BoundsCheckPolicies,
789
790    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
791    /// accesses. These may need to be cached in temporary variables. See
792    /// `index::find_checked_indexes` for details.
793    pub(super) guarded_indices: HandleSet<crate::Expression>,
794    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
795    pub(super) force_loop_bounding: bool,
796    /// Whether to emit safety checks for integer division/modulo.
797    emit_int_div_checks: bool,
798    pub(super) ray_query_initialization_tracking: bool,
799}
800
801impl<'a> ExpressionContext<'a> {
802    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
803        self.info[handle].ty.inner_with(&self.module.types)
804    }
805
806    /// Walks from an inner pointer toward a storage binding array global and
807    /// returns the element index at that global for MSL runtime buffer sizing.
808    fn binding_array_index_from_chain(
809        &self,
810        mut expr: Handle<crate::Expression>,
811        global: Handle<crate::GlobalVariable>,
812    ) -> Option<index::GuardedIndex> {
813        let expressions = &self.function.expressions;
814        loop {
815            match expressions[expr] {
816                crate::Expression::Load { pointer } => expr = pointer,
817                crate::Expression::Access { base, index } => {
818                    if matches!(
819                        expressions[base],
820                        crate::Expression::GlobalVariable(g) if g == global
821                    ) {
822                        return Some(index::GuardedIndex::Expression(index));
823                    }
824                    expr = base;
825                }
826                crate::Expression::AccessIndex { base, index } => {
827                    if matches!(
828                        expressions[base],
829                        crate::Expression::GlobalVariable(g) if g == global
830                    ) {
831                        return Some(index::GuardedIndex::Known(index));
832                    }
833                    expr = base;
834                }
835                crate::Expression::GlobalVariable(_) => return None,
836                _ => return None,
837            }
838        }
839    }
840
841    /// Whether `expr` is directly indexing a global in the outer `Access`/`AccessIndex` shape.
842    fn is_global_access_chain(&self, expr: Handle<crate::Expression>) -> bool {
843        let expressions = &self.function.expressions;
844        match expressions[expr] {
845            crate::Expression::Access { base, .. } => match expressions[base] {
846                crate::Expression::GlobalVariable(_) => true,
847                crate::Expression::Access { .. } => self.is_global_access_chain(base),
848                _ => false,
849            },
850            crate::Expression::AccessIndex { base, .. } => {
851                matches!(expressions[base], crate::Expression::GlobalVariable(_))
852            }
853            _ => false,
854        }
855    }
856
857    fn struct_member_needs_arrow(
858        &self,
859        base: Handle<crate::Expression>,
860        originating_global_ty: impl FnOnce(&crate::TypeInner) -> bool,
861    ) -> bool {
862        let originating_matches = match self.function.originating_global(base) {
863            Some(gv) => {
864                originating_global_ty(&self.module.types[self.module.global_variables[gv].ty].inner)
865            }
866            None => false,
867        };
868        originating_matches && self.is_global_access_chain(base)
869    }
870
871    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
872    ///
873    /// Only mipmapped images need to specify a level of detail. Since 1D
874    /// textures cannot have mipmaps, MSL requires that the level argument to
875    /// texture1d queries and accesses must be a constexpr 0. It's easiest
876    /// just to omit the level entirely for 1D textures.
877    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
878        let image_ty = self.resolve_type(image);
879        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
880            class.is_mipmapped() && dim != crate::ImageDimension::D1
881        } else {
882            false
883        }
884    }
885
886    fn choose_bounds_check_policy(
887        &self,
888        pointer: Handle<crate::Expression>,
889    ) -> index::BoundsCheckPolicy {
890        self.policies
891            .choose_policy(pointer, &self.module.types, self.info)
892    }
893
894    /// See docs for [`proc::index::access_needs_check`].
895    fn access_needs_check(
896        &self,
897        base: Handle<crate::Expression>,
898        index: index::GuardedIndex,
899    ) -> Option<index::IndexableLength> {
900        index::access_needs_check(
901            base,
902            index,
903            self.module,
904            &self.function.expressions,
905            self.info,
906        )
907    }
908
909    /// See docs for [`proc::index::bounds_check_iter`].
910    fn bounds_check_iter(
911        &self,
912        chain: Handle<crate::Expression>,
913    ) -> impl Iterator<Item = BoundsCheck> + '_ {
914        index::bounds_check_iter(chain, self.module, self.function, self.info)
915    }
916
917    /// See docs for [`proc::index::oob_local_types`].
918    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
919        index::oob_local_types(self.module, self.function, self.info, self.policies)
920    }
921
922    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
923        match self.function.expressions[expr_handle] {
924            crate::Expression::AccessIndex { base, index } => {
925                let ty = match *self.resolve_type(base) {
926                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
927                    ref ty => ty,
928                };
929                match *ty {
930                    crate::TypeInner::Struct {
931                        ref members, span, ..
932                    } => should_pack_struct_member(members, span, index as usize, self.module),
933                    _ => None,
934                }
935            }
936            _ => None,
937        }
938    }
939}
940
/// Context for writing statements: the expression context plus
/// statement-level state.
pub(super) struct StatementContext<'a> {
    /// The context used for every expression appearing in the statements.
    pub(super) expression: ExpressionContext<'a>,
    /// Optional name of a result struct; presumably the struct an entry point
    /// returns its outputs in — TODO confirm against callers.
    pub(super) result_struct: Option<&'a str>,
}
945
946impl<W: Write> Writer<W> {
947    /// Creates a new `Writer` instance.
948    pub fn new(out: W) -> Self {
949        Writer {
950            out,
951            names: FastHashMap::default(),
952            named_expressions: Default::default(),
953            need_bake_expressions: Default::default(),
954            namer: proc::Namer::default(),
955            wrapped_functions: FastHashSet::default(),
956            emit_int_div_checks: true,
957            struct_member_pads: FastHashSet::default(),
958            needs_object_memory_barriers: false,
959        }
960    }
961
962    /// Finishes writing and returns the output.
963    // See https://github.com/rust-lang/rust-clippy/issues/4979.
964    pub fn finish(self) -> W {
965        self.out
966    }
967
968    /// Generates statements to be inserted immediately before and at the very
969    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
970    /// The 0th item of the returned tuple should be inserted immediately prior
971    /// to the loop and the 1st item should be inserted at the very start of
972    /// the loop body.
973    ///
974    /// # What is this trying to solve?
975    ///
976    /// In Metal Shading Language, an infinite loop has undefined behavior.
977    /// (This rule is inherited from C++14.) This means that, if the MSL
978    /// compiler determines that a given loop will never exit, it may assume
979    /// that it is never reached. It may thus assume that any conditions
980    /// sufficient to cause the loop to be reached must be false. Like many
981    /// optimizing compilers, MSL uses this kind of analysis to establish limits
982    /// on the range of values variables involved in those conditions might
983    /// hold.
984    ///
985    /// For example, suppose the MSL compiler sees the code:
986    ///
987    /// ```ignore
988    /// if (i >= 10) {
989    ///     while (true) { }
990    /// }
991    /// ```
992    ///
993    /// It will recognize that the `while` loop will never terminate, conclude
994    /// that it must be unreachable, and thus infer that, if this code is
995    /// reached, then `i < 10` at that point.
996    ///
997    /// Now suppose that, at some point where `i` has the same value as above,
998    /// the compiler sees the code:
999    ///
1000    /// ```ignore
1001    /// if (i < 10) {
1002    ///     a[i] = 1;
1003    /// }
1004    /// ```
1005    ///
1006    /// Because the compiler is confident that `i < 10`, it will make the
1007    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
1008    ///
1009    /// ```ignore
1010    /// a[i] = 1;
1011    /// ```
1012    ///
1013    /// If that `if` condition was injected by Naga to implement a bounds check,
1014    /// the MSL compiler's optimizations could allow out-of-bounds array
1015    /// accesses to occur.
1016    ///
1017    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
1018    /// that a loop is infinite, so an attacker could craft a Naga module
1019    /// containing an infinite loop protected by conditions that cause the Metal
1020    /// compiler to remove bounds checks that Naga injected elsewhere in the
1021    /// function.
1022    ///
1023    /// This rewrite could occur even if the conditional assignment appears
1024    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
1025    /// reached. This would allow the attacker to save the results of
1026    /// unauthorized reads somewhere accessible before entering the infinite
1027    /// loop. But even worse, the MSL compiler has been observed to simply
1028    /// delete the infinite loop entirely, so that even code dominated by the
1029    /// loop becomes reachable. This would make the attack even more flexible,
1030    /// since shaders that would appear to never terminate would actually exit
1031    /// nicely, after having stolen data from elsewhere in the GPU address
1032    /// space.
1033    ///
1034    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
1035    /// generates is infinite. One approach would be to add inline assembly to
1036    /// each loop that is annotated as potentially branching out of the loop,
1037    /// but which in fact generates no instructions. Unfortunately, inline
1038    /// assembly is not handled correctly by some Metal device drivers.
1039    ///
1040    /// A previously used approach was to add the following code to the bottom
1041    /// of every loop:
1042    ///
1043    /// ```ignore
1044    /// if (volatile bool unpredictable = false; unpredictable)
1045    ///     break;
1046    /// ```
1047    ///
1048    /// Although the `if` condition will always be false in any real execution,
1049    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
1050    /// it must assume that the `break` might be reached, and hence that the
1051    /// loop is not unbounded. This prevents the range analysis impact described
1052    /// above. Unfortunately this prevented the compiler from making important,
1053    /// and safe, optimizations such as loop unrolling and was observed to
1054    /// significantly hurt performance.
1055    ///
1056    /// Our current approach declares a counter before every loop and
1057    /// increments it every iteration, breaking after 2^64 iterations:
1058    ///
1059    /// ```ignore
1060    /// uint2 loop_bound = uint2(0);
1061    /// while (true) {
1062    ///   if (metal::all(loop_bound == uint2(4294967295))) { break; }
1063    ///   loop_bound += uint2(loop_bound.y == 4294967295, 1);
1064    /// }
1065    /// ```
1066    ///
1067    /// This convinces the compiler that the loop is finite and therefore may
1068    /// execute, whilst at the same time allowing optimizations such as loop
1069    /// unrolling. Furthermore the 64-bit counter is large enough it seems
1070    /// implausible that it would affect the execution of any shader.
1071    ///
1072    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
1073    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
1074    fn gen_force_bounded_loop_statements(
1075        &mut self,
1076        level: back::Level,
1077        context: &StatementContext,
1078    ) -> Option<(String, String)> {
1079        if !context.expression.force_loop_bounding {
1080            return None;
1081        }
1082
1083        let loop_bound_name = self.namer.call("loop_bound");
1084        // Count down from u32::MAX rather than up from 0 to avoid hang on
1085        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
1086        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1087        let level = level.next();
1088        let break_and_inc = format!(
1089            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1090{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1091        );
1092
1093        Some((decl, break_and_inc))
1094    }
1095
1096    fn put_call_parameters(
1097        &mut self,
1098        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1099        context: &ExpressionContext,
1100    ) -> BackendResult {
1101        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1102            writer.put_expression(expr, context, true)
1103        })
1104    }
1105
1106    fn put_call_parameters_impl<C, E>(
1107        &mut self,
1108        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1109        ctx: &C,
1110        put_expression: E,
1111    ) -> BackendResult
1112    where
1113        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1114    {
1115        write!(self.out, "(")?;
1116        for (i, handle) in parameters.enumerate() {
1117            if i != 0 {
1118                write!(self.out, ", ")?;
1119            }
1120            put_expression(self, ctx, handle)?;
1121        }
1122        write!(self.out, ")")?;
1123        Ok(())
1124    }
1125
1126    /// Writes the local variables of the given function, as well as any extra
1127    /// out-of-bounds locals that are needed.
1128    ///
1129    /// The names of the OOB locals are also added to `self.names` at the same
1130    /// time.
1131    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1132        let oob_local_types = context.oob_local_types();
1133        for &ty in oob_local_types.iter() {
1134            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1135            self.names.insert(name_key, self.namer.call("oob"));
1136        }
1137
1138        for (name_key, ty, init) in context
1139            .function
1140            .local_variables
1141            .iter()
1142            .map(|(local_handle, local)| {
1143                let name_key = NameKey::local(context.origin, local_handle);
1144                (name_key, local.ty, local.init)
1145            })
1146            .chain(oob_local_types.iter().map(|&ty| {
1147                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1148                (name_key, ty, None)
1149            }))
1150        {
1151            let ty_name = TypeContext {
1152                handle: ty,
1153                gctx: context.module.to_ctx(),
1154                names: &self.names,
1155                access: crate::StorageAccess::empty(),
1156                first_time: false,
1157            };
1158            write!(
1159                self.out,
1160                "{}{} {}",
1161                back::INDENT,
1162                ty_name,
1163                self.names[&name_key]
1164            )?;
1165            match init {
1166                Some(value) => {
1167                    write!(self.out, " = ")?;
1168                    self.put_expression(value, context, true)?;
1169                }
1170                None => {
1171                    write!(self.out, " = {{}}")?;
1172                }
1173            };
1174            writeln!(self.out, ";")?;
1175
1176            // If this variable is a ray query, put in an initialization tracker.
1177            if context.ray_query_initialization_tracking {
1178                if let crate::TypeInner::RayQuery { .. } = context.module.types[ty].inner {
1179                    writeln!(
1180                        self.out,
1181                        "{}uint {}{} = 0u;",
1182                        back::INDENT,
1183                        super::ray::RAY_QUERY_TRACKER_VARIABLE_PREFIX,
1184                        self.names[&name_key]
1185                    )?;
1186
1187                    writeln!(
1188                        self.out,
1189                        "{}float {}{} = 0.0;",
1190                        back::INDENT,
1191                        super::ray::RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX,
1192                        self.names[&name_key]
1193                    )?;
1194                }
1195            }
1196        }
1197        Ok(())
1198    }
1199
1200    fn put_level_of_detail(
1201        &mut self,
1202        level: LevelOfDetail,
1203        context: &ExpressionContext,
1204    ) -> BackendResult {
1205        match level {
1206            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1207            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1208        }
1209        Ok(())
1210    }
1211
1212    fn put_image_query(
1213        &mut self,
1214        image: Handle<crate::Expression>,
1215        query: &str,
1216        level: Option<LevelOfDetail>,
1217        context: &ExpressionContext,
1218    ) -> BackendResult {
1219        self.put_expression(image, context, false)?;
1220        write!(self.out, ".get_{query}(")?;
1221        if let Some(level) = level {
1222            self.put_level_of_detail(level, context)?;
1223        }
1224        write!(self.out, ")")?;
1225        Ok(())
1226    }
1227
1228    fn put_image_size_query(
1229        &mut self,
1230        image: Handle<crate::Expression>,
1231        level: Option<LevelOfDetail>,
1232        kind: crate::ScalarKind,
1233        context: &ExpressionContext,
1234    ) -> BackendResult {
1235        if let crate::TypeInner::Image {
1236            class: crate::ImageClass::External,
1237            ..
1238        } = *context.resolve_type(image)
1239        {
1240            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1241            self.put_expression(image, context, true)?;
1242            write!(self.out, ")")?;
1243            return Ok(());
1244        }
1245
1246        //Note: MSL only has separate width/height/depth queries,
1247        // so compose the result of them.
1248        let dim = match *context.resolve_type(image) {
1249            crate::TypeInner::Image { dim, .. } => dim,
1250            ref other => unreachable!("Unexpected type {:?}", other),
1251        };
1252        let scalar = crate::Scalar { kind, width: 4 };
1253        let coordinate_type = scalar.to_msl_name();
1254        match dim {
1255            crate::ImageDimension::D1 => {
1256                // Since 1D textures never have mipmaps, MSL requires that the
1257                // `level` argument be a constexpr 0. It's simplest for us just
1258                // to pass `None` and omit the level entirely.
1259                if kind == crate::ScalarKind::Uint {
1260                    // No need to construct a vector. No cast needed.
1261                    self.put_image_query(image, "width", None, context)?;
1262                } else {
1263                    // There's no definition for `int` in the `metal` namespace.
1264                    write!(self.out, "int(")?;
1265                    self.put_image_query(image, "width", None, context)?;
1266                    write!(self.out, ")")?;
1267                }
1268            }
1269            crate::ImageDimension::D2 => {
1270                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1271                self.put_image_query(image, "width", level, context)?;
1272                write!(self.out, ", ")?;
1273                self.put_image_query(image, "height", level, context)?;
1274                write!(self.out, ")")?;
1275            }
1276            crate::ImageDimension::D3 => {
1277                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1278                self.put_image_query(image, "width", level, context)?;
1279                write!(self.out, ", ")?;
1280                self.put_image_query(image, "height", level, context)?;
1281                write!(self.out, ", ")?;
1282                self.put_image_query(image, "depth", level, context)?;
1283                write!(self.out, ")")?;
1284            }
1285            crate::ImageDimension::Cube => {
1286                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1287                self.put_image_query(image, "width", level, context)?;
1288                write!(self.out, ")")?;
1289            }
1290        }
1291        Ok(())
1292    }
1293
1294    fn put_cast_to_uint_scalar_or_vector(
1295        &mut self,
1296        expr: Handle<crate::Expression>,
1297        context: &ExpressionContext,
1298    ) -> BackendResult {
1299        // coordinates in IR are int, but Metal expects uint
1300        match *context.resolve_type(expr) {
1301            crate::TypeInner::Scalar(_) => {
1302                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1303            }
1304            crate::TypeInner::Vector { size, .. } => {
1305                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1306            }
1307            _ => {
1308                return Err(Error::GenericValidation(
1309                    "Invalid type for image coordinate".into(),
1310                ))
1311            }
1312        };
1313
1314        write!(self.out, "(")?;
1315        self.put_expression(expr, context, true)?;
1316        write!(self.out, ")")?;
1317        Ok(())
1318    }
1319
1320    fn put_image_sample_level(
1321        &mut self,
1322        image: Handle<crate::Expression>,
1323        level: crate::SampleLevel,
1324        context: &ExpressionContext,
1325    ) -> BackendResult {
1326        let has_levels = context.image_needs_lod(image);
1327        match level {
1328            crate::SampleLevel::Auto => {}
1329            crate::SampleLevel::Zero => {
1330                //TODO: do we support Zero on `Sampled` image classes?
1331            }
1332            _ if !has_levels => {
1333                log::warn!("1D image can't be sampled with level {level:?}");
1334            }
1335            crate::SampleLevel::Exact(h) => {
1336                write!(self.out, ", {NAMESPACE}::level(")?;
1337                self.put_expression(h, context, true)?;
1338                write!(self.out, ")")?;
1339            }
1340            crate::SampleLevel::Bias(h) => {
1341                write!(self.out, ", {NAMESPACE}::bias(")?;
1342                self.put_expression(h, context, true)?;
1343                write!(self.out, ")")?;
1344            }
1345            crate::SampleLevel::Gradient { x, y } => {
1346                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1347                self.put_expression(x, context, true)?;
1348                write!(self.out, ", ")?;
1349                self.put_expression(y, context, true)?;
1350                write!(self.out, ")")?;
1351            }
1352        }
1353        Ok(())
1354    }
1355
1356    fn put_image_coordinate_limits(
1357        &mut self,
1358        image: Handle<crate::Expression>,
1359        level: Option<LevelOfDetail>,
1360        context: &ExpressionContext,
1361    ) -> BackendResult {
1362        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1363        write!(self.out, " - 1")?;
1364        Ok(())
1365    }
1366
1367    /// General function for writing restricted image indexes.
1368    ///
1369    /// This is used to produce restricted mip levels, array indices, and sample
1370    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1371    /// [`Restrict`] bounds check policy.
1372    ///
1373    /// This function writes an expression of the form:
1374    ///
1375    /// ```ignore
1376    ///
1377    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1378    ///
1379    /// ```
1380    ///
1381    /// [`ImageLoad`]: crate::Expression::ImageLoad
1382    /// [`ImageStore`]: crate::Statement::ImageStore
1383    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1384    fn put_restricted_scalar_image_index(
1385        &mut self,
1386        image: Handle<crate::Expression>,
1387        index: Handle<crate::Expression>,
1388        limit_method: &str,
1389        context: &ExpressionContext,
1390    ) -> BackendResult {
1391        write!(self.out, "{NAMESPACE}::min(uint(")?;
1392        self.put_expression(index, context, true)?;
1393        write!(self.out, "), ")?;
1394        self.put_expression(image, context, false)?;
1395        write!(self.out, ".{limit_method}() - 1)")?;
1396        Ok(())
1397    }
1398
1399    fn put_restricted_texel_address(
1400        &mut self,
1401        image: Handle<crate::Expression>,
1402        address: &TexelAddress,
1403        context: &ExpressionContext,
1404    ) -> BackendResult {
1405        // Write the coordinate.
1406        write!(self.out, "{NAMESPACE}::min(")?;
1407        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1408        write!(self.out, ", ")?;
1409        self.put_image_coordinate_limits(image, address.level, context)?;
1410        write!(self.out, ")")?;
1411
1412        // Write the array index, if present.
1413        if let Some(array_index) = address.array_index {
1414            write!(self.out, ", ")?;
1415            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1416        }
1417
1418        // Write the sample index, if present.
1419        if let Some(sample) = address.sample {
1420            write!(self.out, ", ")?;
1421            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1422        }
1423
1424        // The level of detail should be clamped and cached by
1425        // `put_cache_restricted_level`, so we don't need to clamp it here.
1426        if let Some(level) = address.level {
1427            write!(self.out, ", ")?;
1428            self.put_level_of_detail(level, context)?;
1429        }
1430
1431        Ok(())
1432    }
1433
1434    /// Write an expression that is true if the given image access is in bounds.
1435    fn put_image_access_bounds_check(
1436        &mut self,
1437        image: Handle<crate::Expression>,
1438        address: &TexelAddress,
1439        context: &ExpressionContext,
1440    ) -> BackendResult {
1441        let mut conjunction = "";
1442
1443        // First, check the level of detail. Only if that is in bounds can we
1444        // use it to find the appropriate bounds for the coordinates.
1445        let level = if let Some(level) = address.level {
1446            write!(self.out, "uint(")?;
1447            self.put_level_of_detail(level, context)?;
1448            write!(self.out, ") < ")?;
1449            self.put_expression(image, context, true)?;
1450            write!(self.out, ".get_num_mip_levels()")?;
1451            conjunction = " && ";
1452            Some(level)
1453        } else {
1454            None
1455        };
1456
1457        // Check sample index, if present.
1458        if let Some(sample) = address.sample {
1459            write!(self.out, "uint(")?;
1460            self.put_expression(sample, context, true)?;
1461            write!(self.out, ") < ")?;
1462            self.put_expression(image, context, true)?;
1463            write!(self.out, ".get_num_samples()")?;
1464            conjunction = " && ";
1465        }
1466
1467        // Check array index, if present.
1468        if let Some(array_index) = address.array_index {
1469            write!(self.out, "{conjunction}uint(")?;
1470            self.put_expression(array_index, context, true)?;
1471            write!(self.out, ") < ")?;
1472            self.put_expression(image, context, true)?;
1473            write!(self.out, ".get_array_size()")?;
1474            conjunction = " && ";
1475        }
1476
1477        // Finally, check if the coordinates are within bounds.
1478        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1479            crate::TypeInner::Vector { .. } => true,
1480            _ => false,
1481        };
1482        write!(self.out, "{conjunction}")?;
1483        if coord_is_vector {
1484            write!(self.out, "{NAMESPACE}::all(")?;
1485        }
1486        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1487        write!(self.out, " < ")?;
1488        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1489        if coord_is_vector {
1490            write!(self.out, ")")?;
1491        }
1492
1493        Ok(())
1494    }
1495
1496    fn put_image_load(
1497        &mut self,
1498        load: Handle<crate::Expression>,
1499        image: Handle<crate::Expression>,
1500        mut address: TexelAddress,
1501        context: &ExpressionContext,
1502    ) -> BackendResult {
1503        if let crate::TypeInner::Image {
1504            class: crate::ImageClass::External,
1505            ..
1506        } = *context.resolve_type(image)
1507        {
1508            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1509            self.put_expression(image, context, true)?;
1510            write!(self.out, ", ")?;
1511            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1512            write!(self.out, ")")?;
1513            return Ok(());
1514        }
1515
1516        match context.policies.image_load {
1517            proc::BoundsCheckPolicy::Restrict => {
1518                // Use the cached restricted level of detail, if any. Omit the
1519                // level altogether for 1D textures.
1520                if address.level.is_some() {
1521                    address.level = if context.image_needs_lod(image) {
1522                        Some(LevelOfDetail::Restricted(load))
1523                    } else {
1524                        None
1525                    }
1526                }
1527
1528                self.put_expression(image, context, false)?;
1529                write!(self.out, ".read(")?;
1530                self.put_restricted_texel_address(image, &address, context)?;
1531                write!(self.out, ")")?;
1532            }
1533            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1534                write!(self.out, "(")?;
1535                self.put_image_access_bounds_check(image, &address, context)?;
1536                write!(self.out, " ? ")?;
1537                self.put_unchecked_image_load(image, &address, context)?;
1538                write!(self.out, ": DefaultConstructible())")?;
1539            }
1540            proc::BoundsCheckPolicy::Unchecked => {
1541                self.put_unchecked_image_load(image, &address, context)?;
1542            }
1543        }
1544
1545        Ok(())
1546    }
1547
1548    fn put_unchecked_image_load(
1549        &mut self,
1550        image: Handle<crate::Expression>,
1551        address: &TexelAddress,
1552        context: &ExpressionContext,
1553    ) -> BackendResult {
1554        self.put_expression(image, context, false)?;
1555        write!(self.out, ".read(")?;
1556        // coordinates in IR are int, but Metal expects uint
1557        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1558        if let Some(expr) = address.array_index {
1559            write!(self.out, ", ")?;
1560            self.put_expression(expr, context, true)?;
1561        }
1562        if let Some(sample) = address.sample {
1563            write!(self.out, ", ")?;
1564            self.put_expression(sample, context, true)?;
1565        }
1566        if let Some(level) = address.level {
1567            if context.image_needs_lod(image) {
1568                write!(self.out, ", ")?;
1569                self.put_level_of_detail(level, context)?;
1570            }
1571        }
1572        write!(self.out, ")")?;
1573
1574        Ok(())
1575    }
1576
1577    fn put_image_atomic(
1578        &mut self,
1579        level: back::Level,
1580        image: Handle<crate::Expression>,
1581        address: &TexelAddress,
1582        fun: crate::AtomicFunction,
1583        value: Handle<crate::Expression>,
1584        context: &StatementContext,
1585    ) -> BackendResult {
1586        write!(self.out, "{level}")?;
1587        self.put_expression(image, &context.expression, false)?;
1588        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1589            fun.to_msl_64_bit()?
1590        } else {
1591            fun.to_msl()
1592        };
1593        write!(self.out, ".atomic_{op}(")?;
1594        // coordinates in IR are int, but Metal expects uint
1595        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1596        write!(self.out, ", ")?;
1597        self.put_expression(value, &context.expression, true)?;
1598        writeln!(self.out, ");")?;
1599
1600        // Workaround for Apple Metal TBDR driver bug: fragment shader atomic
1601        // texture writes randomly drop unless followed by a standard texture
1602        // write. Insert a dead-code write behind an unprovable condition so
1603        // the compiler emits proper memory safety barriers.
1604        // See: https://projects.blender.org/blender/blender/commit/aa95220576706122d79c91c7f5c522e6c7416425
1605        let value_ty = context.expression.resolve_type(value);
1606        let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1607            (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1608            (_, Some(8)) => "ulong4(0uL)",
1609            _ => "uint4(0u)",
1610        };
1611        let coord_ty = context.expression.resolve_type(address.coordinate);
1612        let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1613            ""
1614        } else {
1615            ".x"
1616        };
1617        write!(self.out, "{level}if (")?;
1618        self.put_expression(address.coordinate, &context.expression, true)?;
1619        write!(self.out, "{x} == -99999) {{ ")?;
1620        self.put_expression(image, &context.expression, false)?;
1621        write!(self.out, ".write({zero_value}, ")?;
1622        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1623        if let Some(array_index) = address.array_index {
1624            write!(self.out, ", ")?;
1625            self.put_expression(array_index, &context.expression, true)?;
1626        }
1627        writeln!(self.out, "); }}")?;
1628
1629        Ok(())
1630    }
1631
1632    fn put_image_store(
1633        &mut self,
1634        level: back::Level,
1635        image: Handle<crate::Expression>,
1636        address: &TexelAddress,
1637        value: Handle<crate::Expression>,
1638        context: &StatementContext,
1639    ) -> BackendResult {
1640        write!(self.out, "{level}")?;
1641        self.put_expression(image, &context.expression, false)?;
1642        write!(self.out, ".write(")?;
1643        self.put_expression(value, &context.expression, true)?;
1644        write!(self.out, ", ")?;
1645        // coordinates in IR are int, but Metal expects uint
1646        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1647        if let Some(expr) = address.array_index {
1648            write!(self.out, ", ")?;
1649            self.put_expression(expr, &context.expression, true)?;
1650        }
1651        writeln!(self.out, ");")?;
1652
1653        Ok(())
1654    }
1655
1656    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1657    ///
1658    /// The 'maximum valid index' is simply one less than the array's length.
1659    ///
1660    /// This emits an expression of the form `a / b`, so the caller must
1661    /// parenthesize its output if it will be applying operators of higher
1662    /// precedence.
1663    ///
1664    /// `handle` must be the handle of a global variable whose final member is a
1665    /// dynamically sized array.
1666    ///
1667    /// `chain_expr` sits on the pointer path from an inner access, such as the
1668    /// value passed to array length, back toward this global. For storage binding
1669    /// arrays, that path tells us which element index to use in `_buffer_sizes`.
1670    fn binding_array_layout_count(
1671        module: &crate::Module,
1672        pipeline_options: &PipelineOptions,
1673        global: Handle<crate::GlobalVariable>,
1674    ) -> u32 {
1675        let var = &module.global_variables[global];
1676        let crate::TypeInner::BindingArray { size, .. } = module.types[var.ty].inner else {
1677            unreachable!("binding_array_layout_count called on non-binding-array global");
1678        };
1679        let from_shader = match size {
1680            crate::ArraySize::Constant(n) => n.get(),
1681            crate::ArraySize::Pending(_) | crate::ArraySize::Dynamic => 0,
1682        };
1683        let from_layout = var
1684            .binding
1685            .and_then(|br| pipeline_options.binding_array_length_map.get(&br))
1686            .copied()
1687            .unwrap_or(0);
1688        from_shader.max(from_layout).max(1)
1689    }
1690
1691    fn put_binding_array_size_member_index(
1692        &mut self,
1693        index: index::GuardedIndex,
1694        context: &ExpressionContext,
1695    ) -> BackendResult {
1696        match index {
1697            index::GuardedIndex::Expression(expr) => {
1698                write!(self.out, "unsigned(")?;
1699                self.put_expression(expr, context, true)?;
1700                write!(self.out, ")")?;
1701            }
1702            index::GuardedIndex::Known(value) => write!(self.out, "{value}u")?,
1703        }
1704        Ok(())
1705    }
1706
1707    fn put_dynamic_array_max_index(
1708        &mut self,
1709        handle: Handle<crate::GlobalVariable>,
1710        chain_expr: Handle<crate::Expression>,
1711        context: &ExpressionContext,
1712    ) -> BackendResult {
1713        let global = &context.module.global_variables[handle];
1714        let (offset, array_ty) = match context.module.types[global.ty].inner {
1715            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1716                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1717                None => return Err(Error::GenericValidation("Struct has no members".into())),
1718            },
1719            crate::TypeInner::BindingArray { base, .. } => match context.module.types[base].inner {
1720                crate::TypeInner::Struct { ref members, .. } => match members.last() {
1721                    Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1722                    None => return Err(Error::GenericValidation("Struct has no members".into())),
1723                },
1724                _ => {
1725                    return Err(Error::GenericValidation(
1726                        "binding_array element must be a struct with a runtime-sized array field"
1727                            .into(),
1728                    ))
1729                }
1730            },
1731            crate::TypeInner::Array {
1732                size: crate::ArraySize::Dynamic,
1733                ..
1734            } => (0, global.ty),
1735            ref ty => {
1736                return Err(Error::GenericValidation(format!(
1737                    "Expected type with dynamic array, got {ty:?}"
1738                )))
1739            }
1740        };
1741
1742        let (size, stride) = match context.module.types[array_ty].inner {
1743            crate::TypeInner::Array { base, stride, .. } => (
1744                context.module.types[base]
1745                    .inner
1746                    .size(context.module.to_ctx()),
1747                stride,
1748            ),
1749            ref ty => {
1750                return Err(Error::GenericValidation(format!(
1751                    "Expected array type, got {ty:?}"
1752                )))
1753            }
1754        };
1755
1756        // When the stride length is larger than the size, the final element's stride of
1757        // bytes would have padding following the value. But the buffer size in
1758        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1759        // enough to hold the actual values' bytes.
1760        //
1761        // So subtract off the size to get a byte size that falls at the start or within
1762        // the final element. Then divide by the stride size, to get one less than the
1763        // length, and then add one. This works even if the buffer size does include the
1764        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1765        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1766        // rules, together with draw-time validation when `minBindingSize` is zero,
1767        // prevent that.
1768        write!(
1769            self.out,
1770            "(_buffer_sizes.{member}",
1771            member = ArraySizeMember(handle),
1772        )?;
1773        if let crate::TypeInner::BindingArray { .. } = context.module.types[global.ty].inner {
1774            let Some(array_index) = context.binding_array_index_from_chain(chain_expr, handle)
1775            else {
1776                return Err(Error::GenericValidation(
1777                    "Could not find binding_array index for buffer size".into(),
1778                ));
1779            };
1780            write!(self.out, "[")?;
1781            match array_index {
1782                index::GuardedIndex::Expression(expr) => {
1783                    write!(self.out, "unsigned(")?;
1784                    self.put_expression(expr, context, true)?;
1785                    write!(self.out, ")")?;
1786                }
1787                index::GuardedIndex::Known(i) => {
1788                    write!(self.out, "{i}u")?;
1789                }
1790            }
1791            write!(self.out, "]")?;
1792        }
1793        write!(
1794            self.out,
1795            " - {offset} - {size}) / {stride}",
1796            offset = offset,
1797            size = size,
1798            stride = stride,
1799        )?;
1800        Ok(())
1801    }
1802
1803    /// Emit code for the arithmetic expression of the dot product.
1804    ///
1805    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1806    /// an index. It writes out the expression for the vector component at that index.
1807    fn put_dot_product<T: Copy>(
1808        &mut self,
1809        arg: T,
1810        arg1: T,
1811        size: usize,
1812        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1813    ) -> BackendResult {
1814        // Write parentheses around the dot product expression to prevent operators
1815        // with different precedences from applying earlier.
1816        write!(self.out, "(")?;
1817
1818        // Cycle through all the components of the vector
1819        for index in 0..size {
1820            // Write the addition to the previous product
1821            // This will print an extra '+' at the beginning but that is fine in msl
1822            write!(self.out, " + ")?;
1823            extractor(self, arg, index)?;
1824            write!(self.out, " * ")?;
1825            extractor(self, arg1, index)?;
1826        }
1827
1828        write!(self.out, ")")?;
1829        Ok(())
1830    }
1831
1832    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1833    fn put_pack4x8(
1834        &mut self,
1835        arg: Handle<crate::Expression>,
1836        context: &ExpressionContext<'_>,
1837        was_signed: bool,
1838        clamp_bounds: Option<(&str, &str)>,
1839    ) -> Result<(), Error> {
1840        let write_arg = |this: &mut Self| -> BackendResult {
1841            if let Some((min, max)) = clamp_bounds {
1842                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1843                write!(this.out, "{NAMESPACE}::clamp(")?;
1844                this.put_expression(arg, context, true)?;
1845                write!(this.out, ", {min}, {max})")?;
1846            } else {
1847                this.put_expression(arg, context, true)?;
1848            }
1849            Ok(())
1850        };
1851
1852        if context.lang_version >= (2, 1) {
1853            let packed_type = if was_signed {
1854                "packed_char4"
1855            } else {
1856                "packed_uchar4"
1857            };
1858            // Metal uses little endian byte order, which matches what WGSL expects here.
1859            write!(self.out, "as_type<uint>({packed_type}(")?;
1860            write_arg(self)?;
1861            write!(self.out, "))")?;
1862        } else {
1863            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1864            if was_signed {
1865                write!(self.out, "uint(")?;
1866            }
1867            write!(self.out, "(")?;
1868            write_arg(self)?;
1869            write!(self.out, "[0] & 0xFF) | ((")?;
1870            write_arg(self)?;
1871            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1872            write_arg(self)?;
1873            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1874            write_arg(self)?;
1875            write!(self.out, "[3] & 0xFF) << 24)")?;
1876            if was_signed {
1877                write!(self.out, ")")?;
1878            }
1879        }
1880
1881        Ok(())
1882    }
1883
1884    /// Emit code for the isign expression.
1885    ///
1886    fn put_isign(
1887        &mut self,
1888        arg: Handle<crate::Expression>,
1889        context: &ExpressionContext,
1890    ) -> BackendResult {
1891        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1892        let scalar = context
1893            .resolve_type(arg)
1894            .scalar()
1895            .expect("put_isign should only be called for args which have an integer scalar type")
1896            .to_msl_name();
1897        match context.resolve_type(arg) {
1898            &crate::TypeInner::Vector { size, .. } => {
1899                let size = common::vector_size_str(size);
1900                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1901            }
1902            _ => {
1903                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1904            }
1905        }
1906        write!(self.out, ", (")?;
1907        self.put_expression(arg, context, true)?;
1908        write!(self.out, " > 0)), {scalar}(0), (")?;
1909        self.put_expression(arg, context, true)?;
1910        write!(self.out, " == 0))")?;
1911        Ok(())
1912    }
1913
1914    pub(super) fn put_const_expression(
1915        &mut self,
1916        expr_handle: Handle<crate::Expression>,
1917        module: &crate::Module,
1918        mod_info: &valid::ModuleInfo,
1919        arena: &crate::Arena<crate::Expression>,
1920    ) -> BackendResult {
1921        self.put_possibly_const_expression(
1922            expr_handle,
1923            arena,
1924            module,
1925            mod_info,
1926            &(module, mod_info),
1927            |&(_, mod_info), expr| &mod_info[expr],
1928            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1929        )
1930    }
1931
1932    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1933        match literal {
1934            crate::Literal::F64(_) => {
1935                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1936            }
1937            crate::Literal::F16(value) => {
1938                if value.is_infinite() {
1939                    let sign = if value.is_sign_negative() { "-" } else { "" };
1940                    write!(self.out, "{sign}INFINITY")?;
1941                } else if value.is_nan() {
1942                    write!(self.out, "NAN")?;
1943                } else {
1944                    let suffix = if value.fract() == f16::from_f32(0.0) {
1945                        ".0h"
1946                    } else {
1947                        "h"
1948                    };
1949                    write!(self.out, "{value}{suffix}")?;
1950                }
1951            }
1952            crate::Literal::F32(value) => {
1953                if value.is_infinite() {
1954                    let sign = if value.is_sign_negative() { "-" } else { "" };
1955                    write!(self.out, "{sign}INFINITY")?;
1956                } else if value.is_nan() {
1957                    write!(self.out, "NAN")?;
1958                } else {
1959                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1960                    write!(self.out, "{value}{suffix}")?;
1961                }
1962            }
1963            crate::Literal::U16(value) => {
1964                write!(self.out, "static_cast<ushort>({value})")?;
1965            }
1966            crate::Literal::I16(value) => {
1967                write!(self.out, "static_cast<short>({value})")?;
1968            }
1969            crate::Literal::U32(value) => {
1970                write!(self.out, "{value}u")?;
1971            }
1972            crate::Literal::I32(value) => {
1973                // `-2147483648` is parsed as unary negation of positive 2147483648.
1974                // 2147483648 is too large for int32_t meaning the expression gets
1975                // promoted to a int64_t which is not our intention. Avoid this by instead
1976                // using `-2147483647 - 1`.
1977                if value == i32::MIN {
1978                    write!(self.out, "({} - 1)", value + 1)?;
1979                } else {
1980                    write!(self.out, "{value}")?;
1981                }
1982            }
1983            crate::Literal::U64(value) => {
1984                write!(self.out, "{value}uL")?;
1985            }
1986            crate::Literal::I64(value) => {
1987                // `-9223372036854775808` is parsed as unary negation of positive
1988                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1989                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1990                // value to `-9223372036854775808`. Which would then be negated, possibly
1991                // causing undefined behaviour. Avoid this by instead using
1992                // `-9223372036854775808L - 1L`.
1993                if value == i64::MIN {
1994                    write!(self.out, "({}L - 1L)", value + 1)?;
1995                } else {
1996                    write!(self.out, "{value}L")?;
1997                }
1998            }
1999            crate::Literal::Bool(value) => {
2000                write!(self.out, "{value}")?;
2001            }
2002            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2003                return Err(Error::GenericValidation(
2004                    "Unsupported abstract literal".into(),
2005                ));
2006            }
2007        }
2008        Ok(())
2009    }
2010
    /// Write `expr_handle`, an expression from `expressions` of a kind that may
    /// appear in a constant context.
    ///
    /// Only `Literal`, `Constant`, `ZeroValue`, `Compose`, and `Splat` are
    /// handled. Both `put_const_expression` and `put_expression` delegate here.
    /// They differ only in how sub-expressions are typed and written:
    ///
    /// - `ctx` is passed through unchanged to the two callbacks.
    /// - `get_expr_ty(ctx, h)` returns the type resolution of sub-expression `h`.
    ///   Here it is used to find the scalar type of a `Splat` value.
    /// - `put_expression(writer, ctx, h)` writes sub-expression `h`: the
    ///   components of a `Compose` and the value of a `Splat`.
    ///
    /// # Errors
    ///
    /// - [`Error::UnsupportedCompose`] for a `Compose` whose type is not a
    ///   scalar, vector, matrix, array, or struct.
    /// - [`Error::GenericValidation`] for a `Splat` of a non-scalar value.
    /// - [`Error::Override`] for any other expression kind.
    /// - Any error returned by the callbacks or by nested writers.
    #[allow(clippy::too_many_arguments)]
    fn put_possibly_const_expression<C, I, E>(
        &mut self,
        expr_handle: Handle<crate::Expression>,
        expressions: &crate::Arena<crate::Expression>,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        ctx: &C,
        get_expr_ty: I,
        put_expression: E,
    ) -> BackendResult
    where
        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
    {
        match expressions[expr_handle] {
            crate::Expression::Literal(literal) => {
                self.put_literal(literal)?;
            }
            crate::Expression::Constant(handle) => {
                // Named constants are referenced by name. Anonymous ones are
                // inlined by writing their initializer from the module's global
                // expression arena.
                let constant = &module.constants[handle];
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    self.put_const_expression(
                        constant.init,
                        module,
                        mod_info,
                        &module.global_expressions,
                    )?;
                }
            }
            crate::Expression::ZeroValue(ty) => {
                // Written as the type name followed by an empty brace initializer: `T {}`.
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name} {{}}")?;
            }
            crate::Expression::Compose { ty, ref components } => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name}")?;
                match module.types[ty].inner {
                    crate::TypeInner::Scalar(_)
                    | crate::TypeInner::Vector { .. }
                    | crate::TypeInner::Matrix { .. } => {
                        // Constructor-call form: the components follow the
                        // type name as call parameters.
                        self.put_call_parameters_impl(
                            components.iter().copied(),
                            ctx,
                            put_expression,
                        )?;
                    }
                    crate::TypeInner::Array { .. } => {
                        // Naga Arrays are Metal arrays wrapped in structs, so
                        // we need two levels of braces.
                        write!(self.out, " {{{{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}}}")?;
                    }
                    crate::TypeInner::Struct { .. } => {
                        write!(self.out, " {{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            // insert padding initialization, if needed: when
                            // `struct_member_pads` records `(ty, index)`, an extra
                            // `{}` initializer is emitted before this member's value.
                            if self.struct_member_pads.contains(&(ty, index as u32)) {
                                write!(self.out, "{{}}, ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}")?;
                    }
                    _ => return Err(Error::UnsupportedCompose(ty)),
                }
            }
            crate::Expression::Splat { size, value } => {
                // Written as a constructor of the vector type built from the
                // value's scalar type and `size`, applied to the single value.
                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                    crate::TypeInner::Scalar(scalar) => scalar,
                    ref ty => {
                        return Err(Error::GenericValidation(format!(
                            "Expected splat value type must be a scalar, got {ty:?}",
                        )))
                    }
                };
                put_numeric_type(&mut self.out, scalar, &[size])?;
                write!(self.out, "(")?;
                put_expression(self, ctx, value)?;
                write!(self.out, ")")?;
            }
            _ => {
                return Err(Error::Override);
            }
        }

        Ok(())
    }
2122
2123    /// Emit code for the expression `expr_handle`.
2124    ///
2125    /// The `is_scoped` argument is true if the surrounding operators have the
2126    /// precedence of the comma operator, or lower. So, for example:
2127    ///
2128    /// - Pass `true` for `is_scoped` when writing function arguments, an
2129    ///   expression statement, an initializer expression, or anything already
2130    ///   wrapped in parenthesis.
2131    ///
2132    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
2133    ///   almost anything else.
2134    pub(super) fn put_expression(
2135        &mut self,
2136        expr_handle: Handle<crate::Expression>,
2137        context: &ExpressionContext,
2138        is_scoped: bool,
2139    ) -> BackendResult {
2140        if let Some(name) = self.named_expressions.get(&expr_handle) {
2141            write!(self.out, "{name}")?;
2142            return Ok(());
2143        }
2144
2145        let expression = &context.function.expressions[expr_handle];
2146        match *expression {
2147            crate::Expression::Literal(_)
2148            | crate::Expression::Constant(_)
2149            | crate::Expression::ZeroValue(_)
2150            | crate::Expression::Compose { .. }
2151            | crate::Expression::Splat { .. } => {
2152                self.put_possibly_const_expression(
2153                    expr_handle,
2154                    &context.function.expressions,
2155                    context.module,
2156                    context.mod_info,
2157                    context,
2158                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2159                    |writer, context, expr| writer.put_expression(expr, context, true),
2160                )?;
2161            }
2162            crate::Expression::Override(_) => return Err(Error::Override),
2163            crate::Expression::Access { base, .. }
2164            | crate::Expression::AccessIndex { base, .. } => {
2165                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
2166                // Since `put_bounds_checks` and `put_access_chain` handle an entire
2167                // access chain at a time, recursing back through `put_expression` only
2168                // for index expressions and the base object, we will never see intermediate
2169                // `Access` or `AccessIndex` expressions here.
2170                let policy = context.choose_bounds_check_policy(base);
2171                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2172                    && self.put_bounds_checks(
2173                        expr_handle,
2174                        context,
2175                        back::Level(0),
2176                        if is_scoped { "" } else { "(" },
2177                    )?
2178                {
2179                    write!(self.out, " ? ")?;
2180                    self.put_access_chain(expr_handle, policy, context)?;
2181                    write!(self.out, " : ")?;
2182
2183                    if context.resolve_type(base).pointer_space().is_some() {
2184                        // We can't just use `DefaultConstructible` if this is a pointer.
2185                        // Instead, we create a dummy local variable to serve as pointer
2186                        // target if the access is out of bounds.
2187                        let result_ty = context.info[expr_handle]
2188                            .ty
2189                            .inner_with(&context.module.types)
2190                            .pointer_base_type();
2191                        let result_ty_handle = match result_ty {
2192                            Some(TypeResolution::Handle(handle)) => handle,
2193                            Some(TypeResolution::Value(_)) => {
2194                                // As long as the result of a pointer access expression is
2195                                // passed to a function or stored in a let binding, the
2196                                // type will be in the arena. If additional uses of
2197                                // pointers become valid, this assumption might no longer
2198                                // hold. Note that the LHS of a load or store doesn't
2199                                // take this path -- there is dedicated code in `put_load`
2200                                // and `put_store`.
2201                                unreachable!(
2202                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2203                                );
2204                            }
2205                            None => {
2206                                unreachable!(
2207                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2208                                )
2209                            }
2210                        };
2211                        let name_key =
2212                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
2213                        self.out.write_str(&self.names[&name_key])?;
2214                    } else {
2215                        write!(self.out, "DefaultConstructible()")?;
2216                    }
2217
2218                    if !is_scoped {
2219                        write!(self.out, ")")?;
2220                    }
2221                } else {
2222                    self.put_access_chain(expr_handle, policy, context)?;
2223                }
2224            }
2225            crate::Expression::Swizzle {
2226                size,
2227                vector,
2228                pattern,
2229            } => {
2230                self.put_wrapped_expression_for_packed_vec3_access(
2231                    vector,
2232                    context,
2233                    false,
2234                    &Self::put_expression,
2235                )?;
2236                write!(self.out, ".")?;
2237                for &sc in pattern[..size as usize].iter() {
2238                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2239                }
2240            }
2241            crate::Expression::FunctionArgument(index) => {
2242                let name_key = match context.origin {
2243                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2244                    FunctionOrigin::EntryPoint(ep_index) => {
2245                        NameKey::EntryPointArgument(ep_index, index)
2246                    }
2247                };
2248                let name = &self.names[&name_key];
2249                write!(self.out, "{name}")?;
2250            }
2251            crate::Expression::GlobalVariable(handle) => {
2252                let name = &self.names[&NameKey::GlobalVariable(handle)];
2253                write!(self.out, "{name}")?;
2254            }
2255            crate::Expression::LocalVariable(handle) => {
2256                let name_key = NameKey::local(context.origin, handle);
2257                let name = &self.names[&name_key];
2258                write!(self.out, "{name}")?;
2259            }
2260            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2261            crate::Expression::ImageSample {
2262                coordinate,
2263                image,
2264                sampler,
2265                clamp_to_edge: true,
2266                gather: None,
2267                array_index: None,
2268                offset: None,
2269                level: crate::SampleLevel::Zero,
2270                depth_ref: None,
2271            } => {
2272                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2273                self.put_expression(image, context, true)?;
2274                write!(self.out, ", ")?;
2275                self.put_expression(sampler, context, true)?;
2276                write!(self.out, ", ")?;
2277                self.put_expression(coordinate, context, true)?;
2278                write!(self.out, ")")?;
2279            }
2280            crate::Expression::ImageSample {
2281                image,
2282                sampler,
2283                gather,
2284                coordinate,
2285                array_index,
2286                offset,
2287                level,
2288                depth_ref,
2289                clamp_to_edge,
2290            } => {
2291                if clamp_to_edge {
2292                    return Err(Error::GenericValidation(
2293                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2294                    ));
2295                }
2296
2297                let main_op = match gather {
2298                    Some(_) => "gather",
2299                    None => "sample",
2300                };
2301                let comparison_op = match depth_ref {
2302                    Some(_) => "_compare",
2303                    None => "",
2304                };
2305                self.put_expression(image, context, false)?;
2306                write!(self.out, ".{main_op}{comparison_op}(")?;
2307                self.put_expression(sampler, context, true)?;
2308                write!(self.out, ", ")?;
2309                self.put_expression(coordinate, context, true)?;
2310                if let Some(expr) = array_index {
2311                    write!(self.out, ", ")?;
2312                    self.put_expression(expr, context, true)?;
2313                }
2314                if let Some(dref) = depth_ref {
2315                    write!(self.out, ", ")?;
2316                    self.put_expression(dref, context, true)?;
2317                }
2318
2319                self.put_image_sample_level(image, level, context)?;
2320
2321                if let Some(offset) = offset {
2322                    write!(self.out, ", ")?;
2323                    self.put_expression(offset, context, true)?;
2324                }
2325
2326                match gather {
2327                    None | Some(crate::SwizzleComponent::X) => {}
2328                    Some(component) => {
2329                        let is_cube_map = match *context.resolve_type(image) {
2330                            crate::TypeInner::Image {
2331                                dim: crate::ImageDimension::Cube,
2332                                ..
2333                            } => true,
2334                            _ => false,
2335                        };
2336                        // Offset always comes before the gather, except
2337                        // in cube maps where it's not applicable
2338                        if offset.is_none() && !is_cube_map {
2339                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2340                        }
2341                        let letter = back::COMPONENTS[component as usize];
2342                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2343                    }
2344                }
2345                write!(self.out, ")")?;
2346            }
2347            crate::Expression::ImageLoad {
2348                image,
2349                coordinate,
2350                array_index,
2351                sample,
2352                level,
2353            } => {
2354                let address = TexelAddress {
2355                    coordinate,
2356                    array_index,
2357                    sample,
2358                    level: level.map(LevelOfDetail::Direct),
2359                };
2360                self.put_image_load(expr_handle, image, address, context)?;
2361            }
2362            //Note: for all the queries, the signed integers are expected,
2363            // so a conversion is needed.
2364            crate::Expression::ImageQuery { image, query } => match query {
2365                crate::ImageQuery::Size { level } => {
2366                    self.put_image_size_query(
2367                        image,
2368                        level.map(LevelOfDetail::Direct),
2369                        crate::ScalarKind::Uint,
2370                        context,
2371                    )?;
2372                }
2373                crate::ImageQuery::NumLevels => {
2374                    self.put_expression(image, context, false)?;
2375                    write!(self.out, ".get_num_mip_levels()")?;
2376                }
2377                crate::ImageQuery::NumLayers => {
2378                    self.put_expression(image, context, false)?;
2379                    write!(self.out, ".get_array_size()")?;
2380                }
2381                crate::ImageQuery::NumSamples => {
2382                    self.put_expression(image, context, false)?;
2383                    write!(self.out, ".get_num_samples()")?;
2384                }
2385            },
2386            crate::Expression::Unary { op, expr } => {
2387                let op_str = match op {
2388                    crate::UnaryOperator::Negate => {
2389                        match context.resolve_type(expr).scalar_kind() {
2390                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2391                            _ => "-",
2392                        }
2393                    }
2394                    crate::UnaryOperator::LogicalNot => "!",
2395                    crate::UnaryOperator::BitwiseNot => "~",
2396                };
2397                write!(self.out, "{op_str}(")?;
2398                self.put_expression(expr, context, false)?;
2399                write!(self.out, ")")?;
2400            }
2401            crate::Expression::Binary { op, left, right } => {
2402                let kind = context
2403                    .resolve_type(left)
2404                    .scalar_kind()
2405                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2406
2407                if op == crate::BinaryOperator::Divide
2408                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2409                    && context.emit_int_div_checks
2410                {
2411                    write!(self.out, "{DIV_FUNCTION}(")?;
2412                    self.put_expression(left, context, true)?;
2413                    write!(self.out, ", ")?;
2414                    self.put_expression(right, context, true)?;
2415                    write!(self.out, ")")?;
2416                } else if op == crate::BinaryOperator::Modulo
2417                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2418                    && context.emit_int_div_checks
2419                {
2420                    write!(self.out, "{MOD_FUNCTION}(")?;
2421                    self.put_expression(left, context, true)?;
2422                    write!(self.out, ", ")?;
2423                    self.put_expression(right, context, true)?;
2424                    write!(self.out, ")")?;
2425                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2426                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2427                    //
2428                    // float:
2429                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2430                    write!(self.out, "{NAMESPACE}::fmod(")?;
2431                    self.put_expression(left, context, true)?;
2432                    write!(self.out, ", ")?;
2433                    self.put_expression(right, context, true)?;
2434                    write!(self.out, ")")?;
2435                } else if (op == crate::BinaryOperator::Add
2436                    || op == crate::BinaryOperator::Subtract
2437                    || op == crate::BinaryOperator::Multiply)
2438                    && kind == crate::ScalarKind::Sint
2439                {
2440                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2441                        crate::TypeInner::Scalar(scalar) => {
2442                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2443                                kind: crate::ScalarKind::Uint,
2444                                ..scalar
2445                            }))
2446                        }
2447                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2448                            size,
2449                            scalar: crate::Scalar {
2450                                kind: crate::ScalarKind::Uint,
2451                                ..scalar
2452                            },
2453                        }),
2454                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2455                    };
2456
2457                    // Avoid undefined behaviour due to overflowing signed
2458                    // integer arithmetic. Cast the operands to unsigned prior
2459                    // to performing the operation, then cast the result back
2460                    // to signed.
2461                    self.put_bitcasted_expression(
2462                        context.resolve_type(expr_handle),
2463                        expr_handle,
2464                        context,
2465                        &|writer, context, is_scoped| {
2466                            writer.put_binop(
2467                                op,
2468                                left,
2469                                right,
2470                                context,
2471                                is_scoped,
2472                                &|writer, expr, context, _is_scoped| {
2473                                    writer.put_bitcasted_expression(
2474                                        &to_unsigned(context.resolve_type(expr))?,
2475                                        expr,
2476                                        context,
2477                                        &|writer, context, is_scoped| {
2478                                            writer.put_expression(expr, context, is_scoped)
2479                                        },
2480                                    )
2481                                },
2482                            )
2483                        },
2484                    )?;
2485                } else {
2486                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2487                }
2488            }
2489            crate::Expression::Select {
2490                condition,
2491                accept,
2492                reject,
2493            } => match *context.resolve_type(condition) {
2494                crate::TypeInner::Scalar(crate::Scalar {
2495                    kind: crate::ScalarKind::Bool,
2496                    ..
2497                }) => {
2498                    if !is_scoped {
2499                        write!(self.out, "(")?;
2500                    }
2501                    self.put_expression(condition, context, false)?;
2502                    write!(self.out, " ? ")?;
2503                    self.put_expression(accept, context, false)?;
2504                    write!(self.out, " : ")?;
2505                    self.put_expression(reject, context, false)?;
2506                    if !is_scoped {
2507                        write!(self.out, ")")?;
2508                    }
2509                }
2510                crate::TypeInner::Vector {
2511                    scalar:
2512                        crate::Scalar {
2513                            kind: crate::ScalarKind::Bool,
2514                            ..
2515                        },
2516                    ..
2517                } => {
2518                    write!(self.out, "{NAMESPACE}::select(")?;
2519                    self.put_expression(reject, context, true)?;
2520                    write!(self.out, ", ")?;
2521                    self.put_expression(accept, context, true)?;
2522                    write!(self.out, ", ")?;
2523                    self.put_expression(condition, context, true)?;
2524                    write!(self.out, ")")?;
2525                }
2526                ref ty => {
2527                    return Err(Error::GenericValidation(format!(
2528                        "Expected select condition to be a non-bool type, got {ty:?}",
2529                    )))
2530                }
2531            },
2532            crate::Expression::Derivative { axis, expr, .. } => {
2533                use crate::DerivativeAxis as Axis;
2534                let op = match axis {
2535                    Axis::X => "dfdx",
2536                    Axis::Y => "dfdy",
2537                    Axis::Width => "fwidth",
2538                };
2539                write!(self.out, "{NAMESPACE}::{op}")?;
2540                self.put_call_parameters(iter::once(expr), context)?;
2541            }
2542            crate::Expression::Relational { fun, argument } => {
2543                let op = match fun {
2544                    crate::RelationalFunction::Any => "any",
2545                    crate::RelationalFunction::All => "all",
2546                    crate::RelationalFunction::IsNan => "isnan",
2547                    crate::RelationalFunction::IsInf => "isinf",
2548                };
2549                write!(self.out, "{NAMESPACE}::{op}")?;
2550                self.put_call_parameters(iter::once(argument), context)?;
2551            }
2552            crate::Expression::Math {
2553                fun,
2554                arg,
2555                arg1,
2556                arg2,
2557                arg3,
2558            } => {
2559                use crate::MathFunction as Mf;
2560
2561                let arg_type = context.resolve_type(arg);
2562                let scalar_argument = match arg_type {
2563                    &crate::TypeInner::Scalar(_) => true,
2564                    _ => false,
2565                };
2566
2567                let fun_name = match fun {
2568                    // comparison
2569                    Mf::Abs => "abs",
2570                    Mf::Min => "min",
2571                    Mf::Max => "max",
2572                    Mf::Clamp => "clamp",
2573                    Mf::Saturate => "saturate",
2574                    // trigonometry
2575                    Mf::Cos => "cos",
2576                    Mf::Cosh => "cosh",
2577                    Mf::Sin => "sin",
2578                    Mf::Sinh => "sinh",
2579                    Mf::Tan => "tan",
2580                    Mf::Tanh => "tanh",
2581                    Mf::Acos => "acos",
2582                    Mf::Asin => "asin",
2583                    Mf::Atan => "atan",
2584                    Mf::Atan2 => "atan2",
2585                    Mf::Asinh => "asinh",
2586                    Mf::Acosh => "acosh",
2587                    Mf::Atanh => "atanh",
2588                    Mf::Radians => "",
2589                    Mf::Degrees => "",
2590                    // decomposition
2591                    Mf::Ceil => "ceil",
2592                    Mf::Floor => "floor",
2593                    Mf::Round => "rint",
2594                    Mf::Fract => "fract",
2595                    Mf::Trunc => "trunc",
2596                    Mf::Modf => MODF_FUNCTION,
2597                    Mf::Frexp => FREXP_FUNCTION,
2598                    Mf::Ldexp => "ldexp",
2599                    // exponent
2600                    Mf::Exp => "exp",
2601                    Mf::Exp2 => "exp2",
2602                    Mf::Log => "log",
2603                    Mf::Log2 => "log2",
2604                    Mf::Pow => "pow",
2605                    // geometry
2606                    Mf::Dot => match *context.resolve_type(arg) {
2607                        crate::TypeInner::Vector {
2608                            scalar:
2609                                crate::Scalar {
2610                                    // Resolve float values to MSL's builtin dot function.
2611                                    kind: crate::ScalarKind::Float,
2612                                    ..
2613                                },
2614                            ..
2615                        } => "dot",
2616                        crate::TypeInner::Vector {
2617                            size,
2618                            scalar:
2619                                scalar @ crate::Scalar {
2620                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2621                                    ..
2622                                },
2623                        } => {
2624                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2625                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2626                            write!(self.out, "{fun_name}(")?;
2627                            self.put_expression(arg, context, true)?;
2628                            write!(self.out, ", ")?;
2629                            self.put_expression(arg1.unwrap(), context, true)?;
2630                            write!(self.out, ")")?;
2631                            return Ok(());
2632                        }
2633                        _ => unreachable!(
2634                            "Correct TypeInner for dot product should be already validated"
2635                        ),
2636                    },
2637                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2638                        if context.lang_version >= (2, 1) {
2639                            // Write potentially optimizable code using `packed_(u?)char4`.
2640                            // The two function arguments were already reinterpreted as packed (signed
2641                            // or unsigned) chars in `Self::put_block`.
2642                            let packed_type = match fun {
2643                                Mf::Dot4I8Packed => "packed_char4",
2644                                Mf::Dot4U8Packed => "packed_uchar4",
2645                                _ => unreachable!(),
2646                            };
2647
2648                            return self.put_dot_product(
2649                                Reinterpreted::new(packed_type, arg),
2650                                Reinterpreted::new(packed_type, arg1.unwrap()),
2651                                4,
2652                                |writer, arg, index| {
2653                                    // MSL implicitly promotes these (signed or unsigned) chars to
2654                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2655                                    write!(writer.out, "{arg}[{index}]")?;
2656                                    Ok(())
2657                                },
2658                            );
2659                        } else {
2660                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2661                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2662                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2663                            let conversion = match fun {
2664                                Mf::Dot4I8Packed => "int",
2665                                Mf::Dot4U8Packed => "",
2666                                _ => unreachable!(),
2667                            };
2668
2669                            return self.put_dot_product(
2670                                arg,
2671                                arg1.unwrap(),
2672                                4,
2673                                |writer, arg, index| {
2674                                    write!(writer.out, "({conversion}(")?;
2675                                    writer.put_expression(arg, context, true)?;
2676                                    if index == 3 {
2677                                        write!(writer.out, ") >> 24)")?;
2678                                    } else {
2679                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2680                                    }
2681                                    Ok(())
2682                                },
2683                            );
2684                        }
2685                    }
2686                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2687                    Mf::Cross => "cross",
2688                    Mf::Distance => "distance",
2689                    Mf::Length if scalar_argument => "abs",
2690                    Mf::Length => "length",
2691                    Mf::Normalize => "normalize",
2692                    Mf::FaceForward => "faceforward",
2693                    Mf::Reflect => "reflect",
2694                    Mf::Refract => "refract",
2695                    // computational
2696                    Mf::Sign => match arg_type.scalar_kind() {
2697                        Some(crate::ScalarKind::Sint) => {
2698                            return self.put_isign(arg, context);
2699                        }
2700                        _ => "sign",
2701                    },
2702                    Mf::Fma => "fma",
2703                    Mf::Mix => "mix",
2704                    Mf::Step => "step",
2705                    Mf::SmoothStep => "smoothstep",
2706                    Mf::Sqrt => "sqrt",
2707                    Mf::InverseSqrt => "rsqrt",
2708                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2709                    Mf::Transpose => "transpose",
2710                    Mf::Determinant => "determinant",
2711                    Mf::QuantizeToF16 => "",
2712                    // bits
2713                    Mf::CountTrailingZeros => "ctz",
2714                    Mf::CountLeadingZeros => "clz",
2715                    Mf::CountOneBits => "popcount",
2716                    Mf::ReverseBits => "reverse_bits",
2717                    Mf::ExtractBits => "",
2718                    Mf::InsertBits => "",
2719                    Mf::FirstTrailingBit => "",
2720                    Mf::FirstLeadingBit => "",
2721                    // data packing
2722                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2723                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2724                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2725                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2726                    Mf::Pack2x16float => "",
2727                    Mf::Pack4xI8 => "",
2728                    Mf::Pack4xU8 => "",
2729                    Mf::Pack4xI8Clamp => "",
2730                    Mf::Pack4xU8Clamp => "",
2731                    // data unpacking
2732                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2733                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2734                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2735                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2736                    Mf::Unpack2x16float => "",
2737                    Mf::Unpack4xI8 => "",
2738                    Mf::Unpack4xU8 => "",
2739                };
2740
2741                match fun {
2742                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2743                        // reverse_bits is listed as requiring MSL 2.1 but that
2744                        // is a copy/paste error. Looking at previous snapshots
2745                        // on web.archive.org it's present in MSL 1.2.
2746                        //
2747                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2748                        // also talks about MSL 1.2 adding "New integer
2749                        // functions to extract, insert, and reverse bits, as
2750                        // described in Integer Functions."
2751                        if context.lang_version < (1, 2) {
2752                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2753                        }
2754                    }
2755                    _ => {}
2756                }
2757
2758                match fun {
2759                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2760                        write!(self.out, "{ABS_FUNCTION}(")?;
2761                        self.put_expression(arg, context, true)?;
2762                        write!(self.out, ")")?;
2763                    }
2764                    Mf::Distance if scalar_argument => {
2765                        write!(self.out, "{NAMESPACE}::abs(")?;
2766                        self.put_expression(arg, context, false)?;
2767                        write!(self.out, " - ")?;
2768                        self.put_expression(arg1.unwrap(), context, false)?;
2769                        write!(self.out, ")")?;
2770                    }
2771                    Mf::FirstTrailingBit => {
2772                        let scalar = context.resolve_type(arg).scalar().unwrap();
2773                        let constant = scalar.width * 8 + 1;
2774
2775                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2776                        self.put_expression(arg, context, true)?;
2777                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2778                    }
2779                    Mf::FirstLeadingBit => {
2780                        let inner = context.resolve_type(arg);
2781                        let scalar = inner.scalar().unwrap();
2782                        let constant = scalar.width * 8 - 1;
2783
2784                        write!(
2785                            self.out,
2786                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2787                        )?;
2788
2789                        if scalar.kind == crate::ScalarKind::Sint {
2790                            write!(self.out, "{NAMESPACE}::select(")?;
2791                            self.put_expression(arg, context, true)?;
2792                            write!(self.out, ", ~")?;
2793                            self.put_expression(arg, context, true)?;
2794                            write!(self.out, ", ")?;
2795                            self.put_expression(arg, context, true)?;
2796                            write!(self.out, " < 0)")?;
2797                        } else {
2798                            self.put_expression(arg, context, true)?;
2799                        }
2800
2801                        write!(self.out, "), ")?;
2802
2803                        // or metal will complain that select is ambiguous
2804                        match *inner {
2805                            crate::TypeInner::Vector { size, scalar } => {
2806                                let size = common::vector_size_str(size);
2807                                let name = scalar.to_msl_name();
2808                                write!(self.out, "{name}{size}")?;
2809                            }
2810                            crate::TypeInner::Scalar(scalar) => {
2811                                let name = scalar.to_msl_name();
2812                                write!(self.out, "{name}")?;
2813                            }
2814                            _ => (),
2815                        }
2816
2817                        write!(self.out, "(-1), ")?;
2818                        self.put_expression(arg, context, true)?;
2819                        write!(self.out, " == 0")?;
2820                        if scalar.kind == crate::ScalarKind::Sint {
2821                            write!(self.out, " || ")?;
2822                            self.put_expression(arg, context, true)?;
2823                            write!(self.out, " == -1")?;
2824                        }
2825                        write!(self.out, ")")?;
2826                    }
2827                    Mf::Unpack2x16float => {
2828                        write!(self.out, "float2(as_type<half2>(")?;
2829                        self.put_expression(arg, context, false)?;
2830                        write!(self.out, "))")?;
2831                    }
2832                    Mf::Pack2x16float => {
2833                        write!(self.out, "as_type<uint>(half2(")?;
2834                        self.put_expression(arg, context, false)?;
2835                        write!(self.out, "))")?;
2836                    }
2837                    Mf::ExtractBits => {
2838                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2839                        // to first sanitize the offset and count first. If we don't do this, Apple chips
2840                        // will return out-of-spec values if the extracted range is not within the bit width.
2841                        //
2842                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2843                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2844                        //
2845                        // w = sizeof(x) * 8
2846                        // o = min(offset, w)
2847                        // tmp = w - o
2848                        // c = min(count, tmp)
2849                        //
2850                        // bitfieldExtract(x, o, c)
2851                        //
2852                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
2853
2854                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2855
2856                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2857                        self.put_expression(arg, context, true)?;
2858                        write!(self.out, ", {NAMESPACE}::min(")?;
2859                        self.put_expression(arg1.unwrap(), context, true)?;
2860                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2861                        self.put_expression(arg2.unwrap(), context, true)?;
2862                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2863                        self.put_expression(arg1.unwrap(), context, true)?;
2864                        write!(self.out, ", {scalar_bits}u)))")?;
2865                    }
2866                    Mf::InsertBits => {
2867                        // The behavior of InsertBits has the same issue as ExtractBits.
2868                        //
2869                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
2870
2871                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2872
2873                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2874                        self.put_expression(arg, context, true)?;
2875                        write!(self.out, ", ")?;
2876                        self.put_expression(arg1.unwrap(), context, true)?;
2877                        write!(self.out, ", {NAMESPACE}::min(")?;
2878                        self.put_expression(arg2.unwrap(), context, true)?;
2879                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2880                        self.put_expression(arg3.unwrap(), context, true)?;
2881                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2882                        self.put_expression(arg2.unwrap(), context, true)?;
2883                        write!(self.out, ", {scalar_bits}u)))")?;
2884                    }
2885                    Mf::Radians => {
2886                        write!(self.out, "((")?;
2887                        self.put_expression(arg, context, false)?;
2888                        write!(self.out, ") * 0.017453292519943295474)")?;
2889                    }
2890                    Mf::Degrees => {
2891                        write!(self.out, "((")?;
2892                        self.put_expression(arg, context, false)?;
2893                        write!(self.out, ") * 57.295779513082322865)")?;
2894                    }
2895                    Mf::Modf | Mf::Frexp => {
2896                        write!(self.out, "{fun_name}")?;
2897                        self.put_call_parameters(iter::once(arg), context)?;
2898                    }
2899                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2900                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2901                    Mf::Pack4xI8Clamp => {
2902                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2903                    }
2904                    Mf::Pack4xU8Clamp => {
2905                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2906                    }
2907                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2908                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2909                            "u"
2910                        } else {
2911                            ""
2912                        };
2913
2914                        if context.lang_version >= (2, 1) {
2915                            // Metal uses little endian byte order, which matches what WGSL expects here.
2916                            write!(
2917                                self.out,
2918                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2919                            )?;
2920                            self.put_expression(arg, context, true)?;
2921                            write!(self.out, "))")?;
2922                        } else {
2923                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2924                            write!(self.out, "({sign_prefix}int4(")?;
2925                            self.put_expression(arg, context, true)?;
2926                            write!(self.out, ", ")?;
2927                            self.put_expression(arg, context, true)?;
2928                            write!(self.out, " >> 8, ")?;
2929                            self.put_expression(arg, context, true)?;
2930                            write!(self.out, " >> 16, ")?;
2931                            self.put_expression(arg, context, true)?;
2932                            write!(self.out, " >> 24) << 24 >> 24)")?;
2933                        }
2934                    }
2935                    Mf::QuantizeToF16 => {
2936                        match *context.resolve_type(arg) {
2937                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2938                            crate::TypeInner::Vector { size, .. } => write!(
2939                                self.out,
2940                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2941                                size = common::vector_size_str(size),
2942                            )?,
2943                            _ => unreachable!(
2944                                "Correct TypeInner for QuantizeToF16 should be already validated"
2945                            ),
2946                        };
2947
2948                        self.put_expression(arg, context, true)?;
2949                        write!(self.out, "))")?;
2950                    }
2951                    _ => {
2952                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2953                        self.put_call_parameters(
2954                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2955                            context,
2956                        )?;
2957                    }
2958                }
2959            }
2960            crate::Expression::As {
2961                expr,
2962                kind,
2963                convert,
2964            } => match *context.resolve_type(expr) {
2965                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2966                    if src.kind == crate::ScalarKind::Float
2967                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2968                        && convert.is_some()
2969                    {
2970                        // Use helper functions for float to int casts in order to avoid
2971                        // undefined behaviour when value is out of range for the target
2972                        // type.
2973                        let fun_name = match (kind, convert) {
2974                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2975                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2976                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2977                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2978                            _ => unreachable!(),
2979                        };
2980                        write!(self.out, "{fun_name}(")?;
2981                        self.put_expression(expr, context, true)?;
2982                        write!(self.out, ")")?;
2983                    } else {
2984                        let target_scalar = crate::Scalar {
2985                            kind,
2986                            width: convert.unwrap_or(src.width),
2987                        };
2988                        let op = match convert {
2989                            Some(_) => "static_cast",
2990                            None => "as_type",
2991                        };
2992                        write!(self.out, "{op}<")?;
2993                        match *context.resolve_type(expr) {
2994                            crate::TypeInner::Vector { size, .. } => {
2995                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2996                            }
2997                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2998                        };
2999                        write!(self.out, ">(")?;
3000                        self.put_expression(expr, context, true)?;
3001                        write!(self.out, ")")?;
3002                    }
3003                }
3004                crate::TypeInner::Matrix {
3005                    columns,
3006                    rows,
3007                    scalar,
3008                } => {
3009                    let target_scalar = crate::Scalar {
3010                        kind,
3011                        width: convert.unwrap_or(scalar.width),
3012                    };
3013                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
3014                    write!(self.out, "(")?;
3015                    self.put_expression(expr, context, true)?;
3016                    write!(self.out, ")")?;
3017                }
3018                ref ty => {
3019                    return Err(Error::GenericValidation(format!(
3020                        "Unsupported type for As: {ty:?}"
3021                    )))
3022                }
3023            },
3024            // has to be a named expression
3025            crate::Expression::CallResult(_)
3026            | crate::Expression::AtomicResult { .. }
3027            | crate::Expression::WorkGroupUniformLoadResult { .. }
3028            | crate::Expression::SubgroupBallotResult
3029            | crate::Expression::SubgroupOperationResult { .. }
3030            | crate::Expression::RayQueryProceedResult => {
3031                unreachable!()
3032            }
3033            crate::Expression::ArrayLength(expr) => {
3034                let global = context.function.originating_global(expr).ok_or_else(|| {
3035                    Error::GenericValidation(format!(
3036                        "Could not find global variable for ArrayLength operand {:?}",
3037                        context.function.expressions[expr]
3038                    ))
3039                })?;
3040
3041                if !is_scoped {
3042                    write!(self.out, "(")?;
3043                }
3044                write!(self.out, "1 + ")?;
3045                self.put_dynamic_array_max_index(global, expr, context)?;
3046                if !is_scoped {
3047                    write!(self.out, ")")?;
3048                }
3049            }
3050            crate::Expression::RayQueryVertexPositions { .. } => {
3051                unimplemented!()
3052            }
3053            crate::Expression::RayQueryGetIntersection { query, committed } => {
3054                if context.lang_version < (2, 4) {
3055                    return Err(Error::UnsupportedRayTracing);
3056                }
3057
3058                // See comment in `write_ray_query_stmt` for why this is valid
3059                let crate::Expression::LocalVariable(query_var) =
3060                    context.function.expressions[query]
3061                else {
3062                    unreachable!()
3063                };
3064
3065                let tracker_expr_name = format!(
3066                    "{}{}",
3067                    super::ray::RAY_QUERY_TRACKER_VARIABLE_PREFIX,
3068                    self.names[&NameKey::local(context.origin, query_var)]
3069                );
3070
3071                write!(
3072                    self.out,
3073                    "{}_{committed}(",
3074                    super::ray::INTERSECTION_FUNCTION_NAME
3075                )?;
3076                self.put_expression(query, context, true)?;
3077                if context.ray_query_initialization_tracking {
3078                    write!(self.out, ", {tracker_expr_name}")?;
3079                }
3080                write!(self.out, ")")?;
3081            }
3082            crate::Expression::CooperativeLoad { ref data, .. } => {
3083                if context.lang_version < (2, 3) {
3084                    return Err(Error::UnsupportedCooperativeMatrix);
3085                }
3086                write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
3087                write!(self.out, "&")?;
3088                self.put_access_chain(data.pointer, context.policies.index, context)?;
3089                write!(self.out, ", ")?;
3090                self.put_expression(data.stride, context, true)?;
3091                // Metal's `simdgroup_load` treats its `transpose` flag as
3092                // "memory is transposed from the simdgroup_matrix's canonical
3093                // layout". On Apple GPUs that canonical layout is row-major,
3094                // so `transpose=false` loads from row-major memory. WGSL's
3095                // `coopLoadT` (row_major=true) = row-major memory, so it must
3096                // map to `transpose=false`. Hence the negation.
3097                write!(self.out, ", {})", !data.row_major)?;
3098            }
3099            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
3100                if context.lang_version < (2, 3) {
3101                    return Err(Error::UnsupportedCooperativeMatrix);
3102                }
3103                write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
3104                self.put_expression(a, context, true)?;
3105                write!(self.out, ", ")?;
3106                self.put_expression(b, context, true)?;
3107                write!(self.out, ", ")?;
3108                self.put_expression(c, context, true)?;
3109                write!(self.out, ")")?;
3110            }
3111        }
3112        Ok(())
3113    }
3114
3115    /// Emits code for a binary operation, using the provided callback to emit
3116    /// the left and right operands.
3117    fn put_binop<F>(
3118        &mut self,
3119        op: crate::BinaryOperator,
3120        left: Handle<crate::Expression>,
3121        right: Handle<crate::Expression>,
3122        context: &ExpressionContext,
3123        is_scoped: bool,
3124        put_expression: &F,
3125    ) -> BackendResult
3126    where
3127        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3128    {
3129        let op_str = back::binary_operation_str(op);
3130
3131        if !is_scoped {
3132            write!(self.out, "(")?;
3133        }
3134
3135        // Cast packed vector if necessary
3136        // Packed vector - matrix multiplications are not supported in MSL
3137        if op == crate::BinaryOperator::Multiply
3138            && matches!(
3139                context.resolve_type(right),
3140                &crate::TypeInner::Matrix { .. }
3141            )
3142        {
3143            self.put_wrapped_expression_for_packed_vec3_access(
3144                left,
3145                context,
3146                false,
3147                put_expression,
3148            )?;
3149        } else {
3150            put_expression(self, left, context, false)?;
3151        }
3152
3153        write!(self.out, " {op_str} ")?;
3154
3155        // See comment above
3156        if op == crate::BinaryOperator::Multiply
3157            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3158        {
3159            self.put_wrapped_expression_for_packed_vec3_access(
3160                right,
3161                context,
3162                false,
3163                put_expression,
3164            )?;
3165        } else {
3166            put_expression(self, right, context, false)?;
3167        }
3168
3169        if !is_scoped {
3170            write!(self.out, ")")?;
3171        }
3172
3173        Ok(())
3174    }
3175
3176    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
3177    fn put_wrapped_expression_for_packed_vec3_access<F>(
3178        &mut self,
3179        expr_handle: Handle<crate::Expression>,
3180        context: &ExpressionContext,
3181        is_scoped: bool,
3182        put_expression: &F,
3183    ) -> BackendResult
3184    where
3185        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3186    {
3187        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3188            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3189            put_expression(self, expr_handle, context, is_scoped)?;
3190            write!(self.out, ")")?;
3191        } else {
3192            put_expression(self, expr_handle, context, is_scoped)?;
3193        }
3194        Ok(())
3195    }
3196
3197    /// Emits code for an expression using the provided callback, wrapping the
3198    /// result in a bitcast to the type `cast_to`.
3199    fn put_bitcasted_expression<F>(
3200        &mut self,
3201        cast_to: &crate::TypeInner,
3202        inner_expr: Handle<crate::Expression>,
3203        context: &ExpressionContext,
3204        put_expression: &F,
3205    ) -> BackendResult
3206    where
3207        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3208    {
3209        // For sub-32-bit types, C++ integer promotion can widen the inner
3210        // expression (e.g. `ushort + ushort` promotes to `int`), making a
3211        // direct `as_type<short>(int_expr)` invalid due to size mismatch.
3212        // We wrap with `static_cast` to truncate back before the bitcast.
3213        let needs_truncation = match *cast_to {
3214            crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3215            crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3216            _ => false,
3217        };
3218
3219        write!(self.out, "as_type<")?;
3220        match *cast_to {
3221            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3222            crate::TypeInner::Vector { size, scalar } => {
3223                put_numeric_type(&mut self.out, scalar, &[size])?
3224            }
3225            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3226        };
3227        write!(self.out, ">(")?;
3228
3229        if needs_truncation {
3230            write!(self.out, "static_cast<")?;
3231            // Cast to the unsigned version of the target type to truncate
3232            let unsigned_scalar = match *cast_to {
3233                crate::TypeInner::Scalar(scalar) => crate::Scalar {
3234                    kind: crate::ScalarKind::Uint,
3235                    ..scalar
3236                },
3237                crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3238                    kind: crate::ScalarKind::Uint,
3239                    ..scalar
3240                },
3241                _ => unreachable!(),
3242            };
3243            match *cast_to {
3244                crate::TypeInner::Scalar(_) => {
3245                    put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3246                }
3247                crate::TypeInner::Vector { size, .. } => {
3248                    put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3249                }
3250                _ => unreachable!(),
3251            };
3252            write!(self.out, ">(")?;
3253        }
3254
3255        // if it's packed, we must unpack it (e.g., float3(val)) before the bitcast.
3256        if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3257            put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3258            write!(self.out, "(")?;
3259            put_expression(self, context, true)?;
3260            write!(self.out, ")")?;
3261        } else {
3262            put_expression(self, context, true)?;
3263        }
3264
3265        if needs_truncation {
3266            write!(self.out, ")")?;
3267        }
3268
3269        write!(self.out, ")")?;
3270        Ok(())
3271    }
3272
3273    /// Write a `GuardedIndex` as a Metal expression.
3274    fn put_index(
3275        &mut self,
3276        index: index::GuardedIndex,
3277        context: &ExpressionContext,
3278        is_scoped: bool,
3279    ) -> BackendResult {
3280        match index {
3281            index::GuardedIndex::Expression(expr) => {
3282                self.put_expression(expr, context, is_scoped)?
3283            }
3284            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3285        }
3286        Ok(())
3287    }
3288
3289    /// Emit an index bounds check condition for `chain`, if required.
3290    ///
3291    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
3292    /// operating either on a pointer to a value, or on a value directly. If we cannot
3293    /// statically determine that all indexing operations in `chain` are within
3294    /// bounds, then write a conditional expression to check them dynamically,
3295    /// and return true. All accesses in the chain are checked by the generated
3296    /// expression.
3297    ///
3298    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
3299    ///
3300    /// The text written is of the form:
3301    ///
3302    /// ```ignore
3303    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
3304    /// ```
3305    ///
3306    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
3307    /// statements, presumably these arguments start an indented `if` statement; for
3308    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
3309    /// expression. In either case, what is written is not a complete syntactic structure
3310    /// in its own right, and the caller will have to finish it off if we return `true`.
3311    ///
3312    /// If no expression is written, return false.
3313    ///
3314    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
3315    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3316    /// [`Store`]: crate::Statement::Store
3317    /// [`Load`]: crate::Expression::Load
3318    fn put_bounds_checks(
3319        &mut self,
3320        chain: Handle<crate::Expression>,
3321        context: &ExpressionContext,
3322        level: back::Level,
3323        prefix: &'static str,
3324    ) -> Result<bool, Error> {
3325        let mut check_written = false;
3326
3327        // Iterate over the access chain, handling each required bounds check.
3328        for item in context.bounds_check_iter(chain) {
3329            let BoundsCheck {
3330                base,
3331                index,
3332                length,
3333            } = item;
3334
3335            if check_written {
3336                write!(self.out, " && ")?;
3337            } else {
3338                write!(self.out, "{level}{prefix}")?;
3339                check_written = true;
3340            }
3341
3342            // Check that the index falls within bounds. Do this with a single
3343            // comparison, by casting the index to `uint` first, so that negative
3344            // indices become large positive values.
3345            write!(self.out, "uint(")?;
3346            self.put_index(index, context, true)?;
3347            self.out.write_str(") < ")?;
3348            match length {
3349                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3350                index::IndexableLength::Dynamic => {
3351                    let global = context.function.originating_global(base).ok_or_else(|| {
3352                        Error::GenericValidation("Could not find originating global".into())
3353                    })?;
3354                    if matches!(
3355                        context.module.types[context.module.global_variables[global].ty].inner,
3356                        crate::TypeInner::BindingArray { .. }
3357                    ) {
3358                        write!(
3359                            self.out,
3360                            "{} && _buffer_sizes.{}[",
3361                            Self::binding_array_layout_count(
3362                                context.module,
3363                                context.pipeline_options,
3364                                global,
3365                            ),
3366                            ArraySizeMember(global),
3367                        )?;
3368                        self.put_binding_array_size_member_index(index, context)?;
3369                        write!(self.out, "] != 0u")?;
3370                    } else {
3371                        write!(self.out, "1 + ")?;
3372                        self.put_dynamic_array_max_index(global, base, context)?
3373                    }
3374                }
3375            }
3376        }
3377
3378        Ok(check_written)
3379    }
3380
3381    /// Write the access chain `chain`.
3382    ///
3383    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3384    /// operating either on a pointer to a value, or on a value directly.
3385    ///
3386    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3387    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3388    /// that must be handled in the caller.
3389    ///
3390    /// Handle the entire chain, recursing back into `put_expression` only for index
3391    /// expressions and the base expression that originates the pointer or composite value
3392    /// being accessed. This allows `put_expression` to assume that any `Access` or
3393    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3394    /// `ReadZeroSkipWrite` checks.
3395    ///
3396    /// [`Access`]: crate::Expression::Access
3397    /// [`AccessIndex`]: crate::Expression::AccessIndex
3398    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3399    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3400    fn put_access_chain(
3401        &mut self,
3402        chain: Handle<crate::Expression>,
3403        policy: index::BoundsCheckPolicy,
3404        context: &ExpressionContext,
3405    ) -> BackendResult {
3406        match context.function.expressions[chain] {
3407            crate::Expression::Access { base, index } => {
3408                let mut base_ty = context.resolve_type(base);
3409
3410                // Look through any pointers to see what we're really indexing.
3411                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3412                    base_ty = &context.module.types[base].inner;
3413                }
3414
3415                self.put_subscripted_access_chain(
3416                    base,
3417                    base_ty,
3418                    index::GuardedIndex::Expression(index),
3419                    policy,
3420                    context,
3421                )?;
3422            }
3423            crate::Expression::AccessIndex { base, index } => {
3424                let base_resolution = &context.info[base].ty;
3425                let mut base_ty = base_resolution.inner_with(&context.module.types);
3426                let mut base_ty_handle = base_resolution.handle();
3427
3428                // Look through any pointers to see what we're really indexing.
3429                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3430                    base_ty = &context.module.types[base].inner;
3431                    base_ty_handle = Some(base);
3432                }
3433
3434                // Handle structs and anything else that can use `.x` syntax here, so
3435                // `put_subscripted_access_chain` won't have to handle the absurd case of
3436                // indexing a struct with an expression.
3437                match *base_ty {
3438                    crate::TypeInner::Struct { .. } => {
3439                        let base_ty = base_ty_handle.unwrap();
3440                        self.put_access_chain(base, policy, context)?;
3441                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3442                        write!(
3443                            self.out,
3444                            "{}{name}",
3445                            if context.struct_member_needs_arrow(base, |ty| {
3446                                matches!(ty, crate::TypeInner::BindingArray { .. })
3447                            }) {
3448                                "->"
3449                            } else {
3450                                "."
3451                            },
3452                        )?;
3453                    }
3454                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3455                        self.put_access_chain(base, policy, context)?;
3456                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3457                        // however array indexing is
3458                        if context.get_packed_vec_kind(base).is_some() {
3459                            write!(self.out, "[{index}]")?;
3460                        } else {
3461                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3462                        }
3463                    }
3464                    _ => {
3465                        self.put_subscripted_access_chain(
3466                            base,
3467                            base_ty,
3468                            index::GuardedIndex::Known(index),
3469                            policy,
3470                            context,
3471                        )?;
3472                    }
3473                }
3474            }
3475            _ => self.put_expression(chain, context, false)?,
3476        }
3477
3478        Ok(())
3479    }
3480
3481    /// Write a `[]`-style access of `base` by `index`.
3482    ///
3483    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3484    /// values within bounds.
3485    ///
3486    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3487    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3488    /// removed. Our callers often already have this handy.
3489    ///
3490    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3491    /// referencing vector components by name.
3492    ///
3493    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3494    /// [`Array`]: crate::TypeInner::Array
3495    /// [`Vector`]: crate::TypeInner::Vector
3496    /// [`Pointer`]: crate::TypeInner::Pointer
3497    fn put_subscripted_access_chain(
3498        &mut self,
3499        base: Handle<crate::Expression>,
3500        base_ty: &crate::TypeInner,
3501        index: index::GuardedIndex,
3502        policy: index::BoundsCheckPolicy,
3503        context: &ExpressionContext,
3504    ) -> BackendResult {
3505        let accessing_wrapped_array = match *base_ty {
3506            crate::TypeInner::Array {
3507                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3508                ..
3509            } => true,
3510            _ => false,
3511        };
3512        let accessing_wrapped_binding_array =
3513            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3514
3515        self.put_access_chain(base, policy, context)?;
3516        if accessing_wrapped_array {
3517            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3518        }
3519        write!(self.out, "[")?;
3520
3521        // Decide whether this index needs to be clamped to fall within range.
3522        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3523            context.access_needs_check(base, index)
3524        } else {
3525            None
3526        };
3527        if let Some(limit) = restriction_needed {
3528            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3529            self.put_index(index, context, true)?;
3530            write!(self.out, "), ")?;
3531            match limit {
3532                index::IndexableLength::Known(limit) => {
3533                    write!(self.out, "{}u", limit - 1)?;
3534                }
3535                index::IndexableLength::Dynamic => {
3536                    let global = context.function.originating_global(base).ok_or_else(|| {
3537                        Error::GenericValidation("Could not find originating global".into())
3538                    })?;
3539                    self.put_dynamic_array_max_index(global, base, context)?;
3540                }
3541            }
3542            write!(self.out, ")")?;
3543        } else {
3544            self.put_index(index, context, true)?;
3545        }
3546
3547        write!(self.out, "]")?;
3548
3549        if accessing_wrapped_binding_array {
3550            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3551        }
3552
3553        Ok(())
3554    }
3555
3556    fn put_load(
3557        &mut self,
3558        pointer: Handle<crate::Expression>,
3559        context: &ExpressionContext,
3560        is_scoped: bool,
3561    ) -> BackendResult {
3562        // Since access chains never cross between address spaces, we can just
3563        // check the index bounds check policy once at the top.
3564        let policy = context.choose_bounds_check_policy(pointer);
3565        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3566            && self.put_bounds_checks(
3567                pointer,
3568                context,
3569                back::Level(0),
3570                if is_scoped { "" } else { "(" },
3571            )?
3572        {
3573            write!(self.out, " ? ")?;
3574            self.put_unchecked_load(pointer, policy, context)?;
3575            write!(self.out, " : DefaultConstructible()")?;
3576
3577            if !is_scoped {
3578                write!(self.out, ")")?;
3579            }
3580        } else {
3581            self.put_unchecked_load(pointer, policy, context)?;
3582        }
3583
3584        Ok(())
3585    }
3586
3587    fn put_unchecked_load(
3588        &mut self,
3589        pointer: Handle<crate::Expression>,
3590        policy: index::BoundsCheckPolicy,
3591        context: &ExpressionContext,
3592    ) -> BackendResult {
3593        let is_atomic_pointer = context
3594            .resolve_type(pointer)
3595            .is_atomic_pointer(&context.module.types);
3596
3597        if is_atomic_pointer {
3598            write!(
3599                self.out,
3600                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3601            )?;
3602            self.put_access_chain(pointer, policy, context)?;
3603            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3604        } else {
3605            // We don't do any dereferencing with `*` here as pointer arguments to functions
3606            // are done by `&` references and not `*` pointers. These do not need to be
3607            // dereferenced.
3608            self.put_access_chain(pointer, policy, context)?;
3609        }
3610
3611        Ok(())
3612    }
3613
3614    fn put_return_value(
3615        &mut self,
3616        level: back::Level,
3617        expr_handle: Handle<crate::Expression>,
3618        result_struct: Option<&str>,
3619        context: &ExpressionContext,
3620    ) -> BackendResult {
3621        match result_struct {
3622            Some(struct_name) => {
3623                let mut has_point_size = false;
3624                let result_ty = context.function.result.as_ref().unwrap().ty;
3625                match context.module.types[result_ty].inner {
3626                    crate::TypeInner::Struct { ref members, .. } => {
3627                        let tmp = self.namer.call("_tmp");
3628                        write!(self.out, "{level}const auto {tmp} = ")?;
3629                        self.put_expression(expr_handle, context, true)?;
3630                        writeln!(self.out, ";")?;
3631                        write!(self.out, "{level}return {struct_name} {{")?;
3632
3633                        let mut is_first = true;
3634
3635                        for (index, member) in members.iter().enumerate() {
3636                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3637                                member.binding
3638                            {
3639                                has_point_size = true;
3640                                if !context.pipeline_options.allow_and_force_point_size {
3641                                    continue;
3642                                }
3643                            }
3644
3645                            let comma = if is_first { "" } else { "," };
3646                            is_first = false;
3647                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3648                            // HACK: we are forcefully deduplicating the expression here
3649                            // to convert from a wrapped struct to a raw array, e.g.
3650                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3651                            if let crate::TypeInner::Array {
3652                                size: crate::ArraySize::Constant(size),
3653                                ..
3654                            } = context.module.types[member.ty].inner
3655                            {
3656                                write!(self.out, "{comma} {{")?;
3657                                for j in 0..size.get() {
3658                                    if j != 0 {
3659                                        write!(self.out, ",")?;
3660                                    }
3661                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3662                                }
3663                                write!(self.out, "}}")?;
3664                            } else {
3665                                write!(self.out, "{comma} {tmp}.{name}")?;
3666                            }
3667                        }
3668                    }
3669                    _ => {
3670                        write!(self.out, "{level}return {struct_name} {{ ")?;
3671                        self.put_expression(expr_handle, context, true)?;
3672                    }
3673                }
3674
3675                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3676                    let stage = context.module.entry_points[ep_index as usize].stage;
3677                    if context.pipeline_options.allow_and_force_point_size
3678                        && stage == crate::ShaderStage::Vertex
3679                        && !has_point_size
3680                    {
3681                        // point size was injected and comes last
3682                        write!(self.out, ", 1.0")?;
3683                    }
3684                }
3685                write!(self.out, " }}")?;
3686            }
3687            None => {
3688                write!(self.out, "{level}return ")?;
3689                self.put_expression(expr_handle, context, true)?;
3690            }
3691        }
3692        writeln!(self.out, ";")?;
3693        Ok(())
3694    }
3695
3696    /// Helper method used to find which expressions of a given function require baking
3697    ///
3698    /// # Notes
3699    /// This function overwrites the contents of `self.need_bake_expressions`
3700    fn update_expressions_to_bake(
3701        &mut self,
3702        func: &crate::Function,
3703        info: &valid::FunctionInfo,
3704        context: &ExpressionContext,
3705    ) {
3706        use crate::Expression;
3707        self.need_bake_expressions.clear();
3708
3709        for (expr_handle, expr) in func.expressions.iter() {
3710            // Expressions whose reference count is above the
3711            // threshold should always be stored in temporaries.
3712            let expr_info = &info[expr_handle];
3713            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3714            if min_ref_count <= expr_info.ref_count {
3715                self.need_bake_expressions.insert(expr_handle);
3716            } else {
3717                match expr_info.ty {
3718                    // force ray desc to be baked: it's used multiple times internally
3719                    TypeResolution::Handle(h)
3720                        if Some(h) == context.module.special_types.ray_desc =>
3721                    {
3722                        self.need_bake_expressions.insert(expr_handle);
3723                    }
3724                    _ => {}
3725                }
3726            }
3727
3728            if let Expression::Math {
3729                fun,
3730                arg,
3731                arg1,
3732                arg2,
3733                ..
3734            } = *expr
3735            {
3736                match fun {
3737                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3738                    // works on floating-point vectors, so we emit inline code for
3739                    // integer vector `dot` calls. But that code uses each argument `N`
3740                    // times, once for each component (see `put_dot_product`), so to
3741                    // avoid duplicated evaluation, we must bake integer operands.
3742                    // This applies both when using the polyfill (because of the duplicate
3743                    // evaluation issue) and when we don't use the polyfill (because we
3744                    // need them to be emitted before casting to packed chars -- see the
3745                    // comment at the call to `put_casting_to_packed_chars`).
3746                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3747                        self.need_bake_expressions.insert(arg);
3748                        self.need_bake_expressions.insert(arg1.unwrap());
3749                    }
3750                    crate::MathFunction::FirstLeadingBit => {
3751                        self.need_bake_expressions.insert(arg);
3752                    }
3753                    crate::MathFunction::Pack4xI8
3754                    | crate::MathFunction::Pack4xU8
3755                    | crate::MathFunction::Pack4xI8Clamp
3756                    | crate::MathFunction::Pack4xU8Clamp
3757                    | crate::MathFunction::Unpack4xI8
3758                    | crate::MathFunction::Unpack4xU8 => {
3759                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3760                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3761                        if context.lang_version < (2, 1) {
3762                            self.need_bake_expressions.insert(arg);
3763                        }
3764                    }
3765                    crate::MathFunction::ExtractBits => {
3766                        // Only argument 1 is re-used.
3767                        self.need_bake_expressions.insert(arg1.unwrap());
3768                    }
3769                    crate::MathFunction::InsertBits => {
3770                        // Only argument 2 is re-used.
3771                        self.need_bake_expressions.insert(arg2.unwrap());
3772                    }
3773                    crate::MathFunction::Sign => {
3774                        // WGSL's `sign` function works also on signed ints, but Metal's only
3775                        // works on floating points, so we emit inline code for integer `sign`
3776                        // calls. But that code uses each argument 2 times (see `put_isign`),
3777                        // so to avoid duplicated evaluation, we must bake the argument.
3778                        let inner = context.resolve_type(expr_handle);
3779                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3780                            self.need_bake_expressions.insert(arg);
3781                        }
3782                    }
3783                    _ => {}
3784                }
3785            }
3786        }
3787    }
3788
3789    pub(super) fn start_baking_expression(
3790        &mut self,
3791        handle: Handle<crate::Expression>,
3792        context: &ExpressionContext,
3793        name: &str,
3794    ) -> BackendResult {
3795        match context.info[handle].ty {
3796            TypeResolution::Handle(ty_handle) => {
3797                let ty_name = TypeContext {
3798                    handle: ty_handle,
3799                    gctx: context.module.to_ctx(),
3800                    names: &self.names,
3801                    access: crate::StorageAccess::empty(),
3802                    first_time: false,
3803                };
3804                write!(self.out, "{ty_name}")?;
3805            }
3806            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3807                put_numeric_type(&mut self.out, scalar, &[])?;
3808            }
3809            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3810                put_numeric_type(&mut self.out, scalar, &[size])?;
3811            }
3812            TypeResolution::Value(crate::TypeInner::Matrix {
3813                columns,
3814                rows,
3815                scalar,
3816            }) => {
3817                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3818            }
3819            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3820                columns,
3821                rows,
3822                scalar,
3823                role: _,
3824            }) => {
3825                write!(
3826                    self.out,
3827                    "{}::simdgroup_{}{}x{}",
3828                    NAMESPACE,
3829                    scalar.to_msl_name(),
3830                    columns as u32,
3831                    rows as u32,
3832                )?;
3833            }
3834            TypeResolution::Value(ref other) => {
3835                log::warn!("Type {other:?} isn't a known local");
3836                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3837            }
3838        }
3839
3840        //TODO: figure out the naming scheme that wouldn't collide with user names.
3841        write!(self.out, " {name} = ")?;
3842
3843        Ok(())
3844    }
3845
3846    /// Cache a clamped level of detail value, if necessary.
3847    ///
3848    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3849    /// properly clamped level of detail value both in the access itself, and
3850    /// for fetching the size of the requested MIP level, needed to clamp the
3851    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3852    /// it in a temporary variable, as part of the [`Emit`] statement covering
3853    /// the [`ImageLoad`] expression.
3854    ///
3855    /// [`ImageLoad`]: crate::Expression::ImageLoad
3856    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3857    /// [`Emit`]: crate::Statement::Emit
3858    fn put_cache_restricted_level(
3859        &mut self,
3860        load: Handle<crate::Expression>,
3861        image: Handle<crate::Expression>,
3862        mip_level: Option<Handle<crate::Expression>>,
3863        indent: back::Level,
3864        context: &StatementContext,
3865    ) -> BackendResult {
3866        // Does this image access actually require (or even permit) a
3867        // level-of-detail, and does the policy require us to restrict it?
3868        let level_of_detail = match mip_level {
3869            Some(level) => level,
3870            None => return Ok(()),
3871        };
3872
3873        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3874            || !context.expression.image_needs_lod(image)
3875        {
3876            return Ok(());
3877        }
3878
3879        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3880        self.put_restricted_scalar_image_index(
3881            image,
3882            level_of_detail,
3883            "get_num_mip_levels",
3884            &context.expression,
3885        )?;
3886        writeln!(self.out, ";")?;
3887
3888        Ok(())
3889    }
3890
3891    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3892    ///
3893    /// Caches the results in temporary variables (whose names are derived from
3894    /// the original variable names). This caching avoids the need to redo the
3895    /// casting for each vector component when emitting the dot product.
3896    fn put_casting_to_packed_chars(
3897        &mut self,
3898        fun: crate::MathFunction,
3899        arg0: Handle<crate::Expression>,
3900        arg1: Handle<crate::Expression>,
3901        indent: back::Level,
3902        context: &StatementContext<'_>,
3903    ) -> Result<(), Error> {
3904        let packed_type = match fun {
3905            crate::MathFunction::Dot4I8Packed => "packed_char4",
3906            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3907            _ => unreachable!(),
3908        };
3909
3910        for arg in [arg0, arg1] {
3911            write!(
3912                self.out,
3913                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3914                Reinterpreted::new(packed_type, arg)
3915            )?;
3916            self.put_expression(arg, &context.expression, true)?;
3917            writeln!(self.out, ");")?;
3918        }
3919
3920        Ok(())
3921    }
3922
3923    fn put_block(
3924        &mut self,
3925        level: back::Level,
3926        statements: &[crate::Statement],
3927        context: &StatementContext,
3928    ) -> BackendResult {
3929        for statement in statements {
3930            log::trace!("statement[{}] {:?}", level.0, statement);
3931            match *statement {
3932                crate::Statement::Emit(ref range) => {
3933                    for handle in range.clone() {
3934                        use crate::MathFunction as Mf;
3935
3936                        match context.expression.function.expressions[handle] {
3937                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3938                            // may need to cache a clamped version of their level-of-detail argument.
3939                            crate::Expression::ImageLoad {
3940                                image,
3941                                level: mip_level,
3942                                ..
3943                            } => {
3944                                self.put_cache_restricted_level(
3945                                    handle, image, mip_level, level, context,
3946                                )?;
3947                            }
3948
3949                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3950                            // 2.1+ then we introduce two intermediate variables that recast the two
3951                            // arguments as packed (signed or unsigned) chars. The actual dot product
3952                            // is implemented in `Self::put_expression`, and it uses both of these
3953                            // intermediate variables multiple times. There's no danger that the
3954                            // original arguments get modified between the definition of these
3955                            // intermediate variables and the implementation of the actual dot
3956                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3957                            crate::Expression::Math {
3958                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3959                                arg,
3960                                arg1,
3961                                ..
3962                            } if context.expression.lang_version >= (2, 1) => {
3963                                self.put_casting_to_packed_chars(
3964                                    fun,
3965                                    arg,
3966                                    arg1.unwrap(),
3967                                    level,
3968                                    context,
3969                                )?;
3970                            }
3971
3972                            _ => (),
3973                        }
3974
3975                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3976                        let expr_name = if ptr_class.is_some() {
3977                            None // don't bake pointer expressions (just yet)
3978                        } else if let Some(name) =
3979                            context.expression.function.named_expressions.get(&handle)
3980                        {
3981                            // The `crate::Function::named_expressions` table holds
3982                            // expressions that should be saved in temporaries once they
3983                            // are `Emit`ted. We only add them to `self.named_expressions`
3984                            // when we reach the `Emit` that covers them, so that we don't
3985                            // try to use their names before we've actually initialized
3986                            // the temporary that holds them.
3987                            //
3988                            // Don't assume the names in `named_expressions` are unique,
3989                            // or even valid. Use the `Namer`.
3990                            Some(self.namer.call(name))
3991                        } else {
3992                            // If this expression is an index that we're going to first compare
3993                            // against a limit, and then actually use as an index, then we may
3994                            // want to cache it in a temporary, to avoid evaluating it twice.
3995                            let bake = if context.expression.guarded_indices.contains(handle) {
3996                                true
3997                            } else {
3998                                self.need_bake_expressions.contains(&handle)
3999                            };
4000
4001                            if bake {
4002                                Some(Baked(handle).to_string())
4003                            } else {
4004                                None
4005                            }
4006                        };
4007
4008                        if let Some(name) = expr_name {
4009                            write!(self.out, "{level}")?;
4010                            self.start_baking_expression(handle, &context.expression, &name)?;
4011                            self.put_expression(handle, &context.expression, true)?;
4012                            self.named_expressions.insert(handle, name);
4013                            writeln!(self.out, ";")?;
4014                        }
4015                    }
4016                }
4017                crate::Statement::Block(ref block) => {
4018                    if !block.is_empty() {
4019                        writeln!(self.out, "{level}{{")?;
4020                        self.put_block(level.next(), block, context)?;
4021                        writeln!(self.out, "{level}}}")?;
4022                    }
4023                }
4024                crate::Statement::If {
4025                    condition,
4026                    ref accept,
4027                    ref reject,
4028                } => {
4029                    write!(self.out, "{level}if (")?;
4030                    self.put_expression(condition, &context.expression, true)?;
4031                    writeln!(self.out, ") {{")?;
4032                    self.put_block(level.next(), accept, context)?;
4033                    if !reject.is_empty() {
4034                        writeln!(self.out, "{level}}} else {{")?;
4035                        self.put_block(level.next(), reject, context)?;
4036                    }
4037                    writeln!(self.out, "{level}}}")?;
4038                }
4039                crate::Statement::Switch {
4040                    selector,
4041                    ref cases,
4042                } => {
4043                    write!(self.out, "{level}switch(")?;
4044                    self.put_expression(selector, &context.expression, true)?;
4045                    writeln!(self.out, ") {{")?;
4046                    let lcase = level.next();
4047                    for case in cases.iter() {
4048                        match case.value {
4049                            crate::SwitchValue::I32(value) => {
4050                                write!(self.out, "{lcase}case {value}:")?;
4051                            }
4052                            crate::SwitchValue::U32(value) => {
4053                                write!(self.out, "{lcase}case {value}u:")?;
4054                            }
4055                            crate::SwitchValue::Default => {
4056                                write!(self.out, "{lcase}default:")?;
4057                            }
4058                        }
4059
4060                        let write_block_braces = !(case.fall_through && case.body.is_empty());
4061                        if write_block_braces {
4062                            writeln!(self.out, " {{")?;
4063                        } else {
4064                            writeln!(self.out)?;
4065                        }
4066
4067                        self.put_block(lcase.next(), &case.body, context)?;
4068                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
4069                        {
4070                            writeln!(self.out, "{}break;", lcase.next())?;
4071                        }
4072
4073                        if write_block_braces {
4074                            writeln!(self.out, "{lcase}}}")?;
4075                        }
4076                    }
4077                    writeln!(self.out, "{level}}}")?;
4078                }
4079                crate::Statement::Loop {
4080                    ref body,
4081                    ref continuing,
4082                    break_if,
4083                } => {
4084                    let force_loop_bound_statements =
4085                        self.gen_force_bounded_loop_statements(level, context);
4086                    let gate_name = (!continuing.is_empty() || break_if.is_some())
4087                        .then(|| self.namer.call("loop_init"));
4088
4089                    if let Some((ref decl, _)) = force_loop_bound_statements {
4090                        writeln!(self.out, "{decl}")?;
4091                    }
4092                    if let Some(ref gate_name) = gate_name {
4093                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
4094                    }
4095
4096                    writeln!(self.out, "{level}while(true) {{",)?;
4097                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
4098                        writeln!(self.out, "{break_and_inc}")?;
4099                    }
4100                    if let Some(ref gate_name) = gate_name {
4101                        let lif = level.next();
4102                        let lcontinuing = lif.next();
4103                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
4104                        self.put_block(lcontinuing, continuing, context)?;
4105                        if let Some(condition) = break_if {
4106                            write!(self.out, "{lcontinuing}if (")?;
4107                            self.put_expression(condition, &context.expression, true)?;
4108                            writeln!(self.out, ") {{")?;
4109                            writeln!(self.out, "{}break;", lcontinuing.next())?;
4110                            writeln!(self.out, "{lcontinuing}}}")?;
4111                        }
4112                        writeln!(self.out, "{lif}}}")?;
4113                        writeln!(self.out, "{lif}{gate_name} = false;")?;
4114                    }
4115                    self.put_block(level.next(), body, context)?;
4116
4117                    writeln!(self.out, "{level}}}")?;
4118                }
4119                crate::Statement::Break => {
4120                    writeln!(self.out, "{level}break;")?;
4121                }
4122                crate::Statement::Continue => {
4123                    writeln!(self.out, "{level}continue;")?;
4124                }
4125                crate::Statement::Return {
4126                    value: Some(expr_handle),
4127                } => {
4128                    self.put_return_value(
4129                        level,
4130                        expr_handle,
4131                        context.result_struct,
4132                        &context.expression,
4133                    )?;
4134                }
4135                crate::Statement::Return { value: None } => {
4136                    writeln!(self.out, "{level}return;")?;
4137                }
4138                crate::Statement::Kill => {
4139                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
4140                }
4141                crate::Statement::ControlBarrier(flags)
4142                | crate::Statement::MemoryBarrier(flags) => {
4143                    self.write_barrier(flags, level)?;
4144                }
4145                crate::Statement::Store { pointer, value } => {
4146                    self.put_store(pointer, value, level, context)?
4147                }
4148                crate::Statement::ImageStore {
4149                    image,
4150                    coordinate,
4151                    array_index,
4152                    value,
4153                } => {
4154                    let address = TexelAddress {
4155                        coordinate,
4156                        array_index,
4157                        sample: None,
4158                        level: None,
4159                    };
4160                    self.put_image_store(level, image, &address, value, context)?
4161                }
4162                crate::Statement::Call {
4163                    function,
4164                    ref arguments,
4165                    result,
4166                } => {
4167                    write!(self.out, "{level}")?;
4168                    if let Some(expr) = result {
4169                        let name = Baked(expr).to_string();
4170                        self.start_baking_expression(expr, &context.expression, &name)?;
4171                        self.named_expressions.insert(expr, name);
4172                    }
4173                    let fun_name = &self.names[&NameKey::Function(function)];
4174                    write!(self.out, "{fun_name}(")?;
4175                    // first, write down the actual arguments
4176                    for (i, &handle) in arguments.iter().enumerate() {
4177                        if i != 0 {
4178                            write!(self.out, ", ")?;
4179                        }
4180                        self.put_expression(handle, &context.expression, true)?;
4181                    }
4182                    // follow-up with any global resources used
4183                    let mut separate = !arguments.is_empty();
4184                    let fun_info = &context.expression.mod_info[function];
4185                    let mut needs_buffer_sizes = false;
4186                    for (handle, var) in context.expression.module.global_variables.iter() {
4187                        if fun_info[handle].is_empty() {
4188                            continue;
4189                        }
4190                        if var.space.needs_pass_through() {
4191                            let name = &self.names[&NameKey::GlobalVariable(handle)];
4192                            if separate {
4193                                write!(self.out, ", ")?;
4194                            } else {
4195                                separate = true;
4196                            }
4197                            write!(self.out, "{name}")?;
4198                        }
4199                        needs_buffer_sizes |= context.expression.module.types[var.ty]
4200                            .inner
4201                            .needs_host_buffer_byte_size(&context.expression.module.types);
4202                    }
4203                    if needs_buffer_sizes {
4204                        if separate {
4205                            write!(self.out, ", ")?;
4206                        }
4207                        write!(self.out, "_buffer_sizes")?;
4208                    }
4209
4210                    // done
4211                    writeln!(self.out, ");")?;
4212                }
4213                crate::Statement::Atomic {
4214                    pointer,
4215                    ref fun,
4216                    value,
4217                    result,
4218                } => {
4219                    let context = &context.expression;
4220
4221                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
4222                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
4223                    // `Some`, we are not operating on a 64-bit value, and that if we are
4224                    // operating on a 64-bit value, `result` is `None`.
4225                    write!(self.out, "{level}")?;
4226                    let fun_key = if let Some(result) = result {
4227                        let res_name = Baked(result).to_string();
4228                        self.start_baking_expression(result, context, &res_name)?;
4229                        self.named_expressions.insert(result, res_name);
4230                        fun.to_msl()
4231                    } else if context.resolve_type(value).scalar_width() == Some(8) {
4232                        fun.to_msl_64_bit()?
4233                    } else {
4234                        fun.to_msl()
4235                    };
4236
4237                    // If the pointer we're passing to the atomic operation needs to be conditional
4238                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
4239                    // the pointer operand should be unchecked.
4240                    let policy = context.choose_bounds_check_policy(pointer);
4241                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4242                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4243
4244                    // If requested and successfully put bounds checks, continue the ternary expression.
4245                    if checked {
4246                        write!(self.out, " ? ")?;
4247                    }
4248
4249                    // Put the atomic function invocation.
4250                    match *fun {
4251                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4252                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4253                            self.put_access_chain(pointer, policy, context)?;
4254                            write!(self.out, ", ")?;
4255                            self.put_expression(cmp, context, true)?;
4256                            write!(self.out, ", ")?;
4257                            self.put_expression(value, context, true)?;
4258                            write!(self.out, ")")?;
4259                        }
4260                        _ => {
4261                            write!(
4262                                self.out,
4263                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4264                            )?;
4265                            self.put_access_chain(pointer, policy, context)?;
4266                            write!(self.out, ", ")?;
4267                            self.put_expression(value, context, true)?;
4268                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4269                        }
4270                    }
4271
4272                    // Finish the ternary expression.
4273                    if checked {
4274                        write!(self.out, " : DefaultConstructible()")?;
4275                    }
4276
4277                    // Done
4278                    writeln!(self.out, ";")?;
4279                }
4280                crate::Statement::ImageAtomic {
4281                    image,
4282                    coordinate,
4283                    array_index,
4284                    fun,
4285                    value,
4286                } => {
4287                    let address = TexelAddress {
4288                        coordinate,
4289                        array_index,
4290                        sample: None,
4291                        level: None,
4292                    };
4293                    self.put_image_atomic(level, image, &address, fun, value, context)?
4294                }
4295                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4296                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4297
4298                    write!(self.out, "{level}")?;
4299                    let name = self.namer.call("");
4300                    self.start_baking_expression(result, &context.expression, &name)?;
4301                    self.put_load(pointer, &context.expression, true)?;
4302                    self.named_expressions.insert(result, name);
4303
4304                    writeln!(self.out, ";")?;
4305                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4306                }
4307                crate::Statement::RayQuery { query, ref fun } => {
4308                    self.write_ray_query_stmt(level, context, query, fun)?;
4309                }
4310                crate::Statement::SubgroupBallot { result, predicate } => {
4311                    write!(self.out, "{level}")?;
4312                    let name = self.namer.call("");
4313                    self.start_baking_expression(result, &context.expression, &name)?;
4314                    self.named_expressions.insert(result, name);
4315                    write!(
4316                        self.out,
4317                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4318                    )?;
4319                    if let Some(predicate) = predicate {
4320                        self.put_expression(predicate, &context.expression, true)?;
4321                    } else {
4322                        write!(self.out, "true")?;
4323                    }
4324                    writeln!(self.out, "), 0, 0, 0);")?;
4325                }
4326                crate::Statement::SubgroupCollectiveOperation {
4327                    op,
4328                    collective_op,
4329                    argument,
4330                    result,
4331                } => {
4332                    write!(self.out, "{level}")?;
4333                    let name = self.namer.call("");
4334                    self.start_baking_expression(result, &context.expression, &name)?;
4335                    self.named_expressions.insert(result, name);
4336                    match (collective_op, op) {
4337                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4338                            write!(self.out, "{NAMESPACE}::simd_all(")?
4339                        }
4340                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4341                            write!(self.out, "{NAMESPACE}::simd_any(")?
4342                        }
4343                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4344                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4345                        }
4346                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4347                            write!(self.out, "{NAMESPACE}::simd_product(")?
4348                        }
4349                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4350                            write!(self.out, "{NAMESPACE}::simd_max(")?
4351                        }
4352                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4353                            write!(self.out, "{NAMESPACE}::simd_min(")?
4354                        }
4355                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4356                            write!(self.out, "{NAMESPACE}::simd_and(")?
4357                        }
4358                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4359                            write!(self.out, "{NAMESPACE}::simd_or(")?
4360                        }
4361                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4362                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4363                        }
4364                        (
4365                            crate::CollectiveOperation::ExclusiveScan,
4366                            crate::SubgroupOperation::Add,
4367                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4368                        (
4369                            crate::CollectiveOperation::ExclusiveScan,
4370                            crate::SubgroupOperation::Mul,
4371                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4372                        (
4373                            crate::CollectiveOperation::InclusiveScan,
4374                            crate::SubgroupOperation::Add,
4375                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4376                        (
4377                            crate::CollectiveOperation::InclusiveScan,
4378                            crate::SubgroupOperation::Mul,
4379                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4380                        _ => unimplemented!(),
4381                    }
4382                    self.put_expression(argument, &context.expression, true)?;
4383                    writeln!(self.out, ");")?;
4384                }
4385                crate::Statement::SubgroupGather {
4386                    mode,
4387                    argument,
4388                    result,
4389                } => {
4390                    write!(self.out, "{level}")?;
4391                    let name = self.namer.call("");
4392                    self.start_baking_expression(result, &context.expression, &name)?;
4393                    self.named_expressions.insert(result, name);
4394                    match mode {
4395                        crate::GatherMode::BroadcastFirst => {
4396                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4397                        }
4398                        crate::GatherMode::Broadcast(_) => {
4399                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4400                        }
4401                        crate::GatherMode::Shuffle(_) => {
4402                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4403                        }
4404                        crate::GatherMode::ShuffleDown(_) => {
4405                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4406                        }
4407                        crate::GatherMode::ShuffleUp(_) => {
4408                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4409                        }
4410                        crate::GatherMode::ShuffleXor(_) => {
4411                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4412                        }
4413                        crate::GatherMode::QuadBroadcast(_) => {
4414                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4415                        }
4416                        crate::GatherMode::QuadSwap(_) => {
4417                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4418                        }
4419                    }
4420                    self.put_expression(argument, &context.expression, true)?;
4421                    match mode {
4422                        crate::GatherMode::BroadcastFirst => {}
4423                        crate::GatherMode::Broadcast(index)
4424                        | crate::GatherMode::Shuffle(index)
4425                        | crate::GatherMode::ShuffleDown(index)
4426                        | crate::GatherMode::ShuffleUp(index)
4427                        | crate::GatherMode::ShuffleXor(index)
4428                        | crate::GatherMode::QuadBroadcast(index) => {
4429                            write!(self.out, ", ")?;
4430                            self.put_expression(index, &context.expression, true)?;
4431                        }
4432                        crate::GatherMode::QuadSwap(direction) => {
4433                            write!(self.out, ", ")?;
4434                            match direction {
4435                                crate::Direction::X => {
4436                                    write!(self.out, "1u")?;
4437                                }
4438                                crate::Direction::Y => {
4439                                    write!(self.out, "2u")?;
4440                                }
4441                                crate::Direction::Diagonal => {
4442                                    write!(self.out, "3u")?;
4443                                }
4444                            }
4445                        }
4446                    }
4447                    writeln!(self.out, ");")?;
4448                }
4449                crate::Statement::CooperativeStore { target, ref data } => {
4450                    write!(self.out, "{level}simdgroup_store(")?;
4451                    self.put_expression(target, &context.expression, true)?;
4452                    write!(self.out, ", &")?;
4453                    self.put_access_chain(
4454                        data.pointer,
4455                        context.expression.policies.index,
4456                        &context.expression,
4457                    )?;
4458                    write!(self.out, ", ")?;
4459                    self.put_expression(data.stride, &context.expression, true)?;
4460                    // See the comment in `CooperativeLoad` above: WGSL's
4461                    // row_major flag is negated when emitting Metal's
4462                    // `transpose` flag, so a col-major store (row_major=false)
4463                    // must use `transpose=true`.
4464                    if !data.row_major {
4465                        let matrix_origin = "0";
4466                        let transpose = true;
4467                        write!(self.out, ", {matrix_origin}, {transpose}")?;
4468                    }
4469                    writeln!(self.out, ");")?;
4470                }
4471                crate::Statement::RayPipelineFunction(_) => unreachable!(),
4472            }
4473        }
4474
4475        // un-emit expressions
4476        //TODO: take care of loop/continuing?
4477        for statement in statements {
4478            if let crate::Statement::Emit(ref range) = *statement {
4479                for handle in range.clone() {
4480                    self.named_expressions.shift_remove(&handle);
4481                }
4482            }
4483        }
4484        Ok(())
4485    }
4486
4487    fn put_store(
4488        &mut self,
4489        pointer: Handle<crate::Expression>,
4490        value: Handle<crate::Expression>,
4491        level: back::Level,
4492        context: &StatementContext,
4493    ) -> BackendResult {
4494        let policy = context.expression.choose_bounds_check_policy(pointer);
4495        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4496            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4497        {
4498            writeln!(self.out, ") {{")?;
4499            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4500            writeln!(self.out, "{level}}}")?;
4501        } else {
4502            self.put_unchecked_store(pointer, value, policy, level, context)?;
4503        }
4504
4505        Ok(())
4506    }
4507
4508    fn put_unchecked_store(
4509        &mut self,
4510        pointer: Handle<crate::Expression>,
4511        value: Handle<crate::Expression>,
4512        policy: index::BoundsCheckPolicy,
4513        level: back::Level,
4514        context: &StatementContext,
4515    ) -> BackendResult {
4516        let is_atomic_pointer = context
4517            .expression
4518            .resolve_type(pointer)
4519            .is_atomic_pointer(&context.expression.module.types);
4520
4521        if is_atomic_pointer {
4522            write!(
4523                self.out,
4524                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4525            )?;
4526            self.put_access_chain(pointer, policy, &context.expression)?;
4527            write!(self.out, ", ")?;
4528            self.put_expression(value, &context.expression, true)?;
4529            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4530        } else {
4531            write!(self.out, "{level}")?;
4532            self.put_access_chain(pointer, policy, &context.expression)?;
4533            write!(self.out, " = ")?;
4534            self.put_expression(value, &context.expression, true)?;
4535            writeln!(self.out, ";")?;
4536        }
4537
4538        Ok(())
4539    }
4540
4541    pub fn write(
4542        &mut self,
4543        module: &crate::Module,
4544        info: &valid::ModuleInfo,
4545        options: &Options,
4546        pipeline_options: &PipelineOptions,
4547    ) -> Result<TranslationInfo, Error> {
4548        self.emit_int_div_checks = options.emit_int_div_checks;
4549        self.names.clear();
4550        self.namer.reset(
4551            module,
4552            &super::keywords::RESERVED_SET,
4553            proc::KeywordSet::empty(),
4554            proc::CaseInsensitiveKeywordSet::empty(),
4555            &[
4556                CLAMPED_LOD_LOAD_PREFIX,
4557                super::ray::INTERSECTION_FUNCTION_NAME,
4558                super::ray::RAY_QUERY_TRACKER_VARIABLE_PREFIX,
4559                super::ray::RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX,
4560            ],
4561            &mut self.names,
4562        );
4563        self.wrapped_functions.clear();
4564        self.struct_member_pads.clear();
4565
4566        writeln!(
4567            self.out,
4568            "// language: metal{}.{}",
4569            options.lang_version.0, options.lang_version.1
4570        )?;
4571        writeln!(self.out, "#include <metal_stdlib>")?;
4572        writeln!(self.out, "#include <simd/simd.h>")?;
4573        writeln!(self.out)?;
4574        // Work around Metal bug where `uint` is not available by default
4575        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4576
4577        if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4578            return Err(Error::UnsupportedMeshShader);
4579        }
4580        self.needs_object_memory_barriers = module
4581            .entry_points
4582            .iter()
4583            .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4584
4585        if module.special_types.ray_desc.is_some()
4586            || module.special_types.ray_intersection.is_some()
4587        {
4588            if options.lang_version < (2, 4) {
4589                return Err(Error::UnsupportedRayTracing);
4590            }
4591        }
4592
4593        if options
4594            .bounds_check_policies
4595            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4596        {
4597            self.put_default_constructible()?;
4598        }
4599        writeln!(self.out)?;
4600
4601        {
4602            // Make a `Vec` of all the `GlobalVariable`s that contain
4603            // runtime-sized arrays.
4604            let globals: Vec<Handle<crate::GlobalVariable>> = module
4605                .global_variables
4606                .iter()
4607                .filter(|&(_, var)| {
4608                    module.types[var.ty]
4609                        .inner
4610                        .needs_host_buffer_byte_size(&module.types)
4611                })
4612                .map(|(handle, _)| handle)
4613                .collect();
4614
4615            let mut buffer_indices = vec![];
4616            for vbm in &pipeline_options.vertex_buffer_mappings {
4617                buffer_indices.push(vbm.id);
4618            }
4619
4620            if !globals.is_empty() || !buffer_indices.is_empty() {
4621                writeln!(self.out, "struct _mslBufferSizes {{")?;
4622
4623                for global in globals {
4624                    let var = &module.global_variables[global];
4625                    let var_ty = var.ty;
4626                    match module.types[var_ty].inner {
4627                        crate::TypeInner::BindingArray { .. } => {
4628                            let n =
4629                                Self::binding_array_layout_count(module, pipeline_options, global);
4630                            writeln!(
4631                                self.out,
4632                                "{}uint {}[{n}];",
4633                                back::INDENT,
4634                                ArraySizeMember(global),
4635                            )?;
4636                        }
4637                        _ => writeln!(
4638                            self.out,
4639                            "{}uint {};",
4640                            back::INDENT,
4641                            ArraySizeMember(global)
4642                        )?,
4643                    }
4644                }
4645
4646                for idx in buffer_indices {
4647                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4648                }
4649
4650                writeln!(self.out, "}};")?;
4651                writeln!(self.out)?;
4652            }
4653        };
4654
4655        self.write_type_defs(module)?;
4656        self.write_global_constants(module, info)?;
4657        self.write_functions(module, info, options, pipeline_options)
4658    }
4659
4660    /// Write the definition for the `DefaultConstructible` class.
4661    ///
4662    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4663    /// produce 'zero' values for any type, including structs, arrays, and so
4664    /// on. We could do this by emitting default constructor applications, but
4665    /// that would entail printing the name of the type, which is more trouble
4666    /// than you'd think. Instead, we just construct this magic C++14 class that
4667    /// can be converted to any type that can be default constructed, using
4668    /// template parameter inference to detect which type is needed, so we don't
4669    /// have to figure out the name.
4670    ///
4671    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4672    fn put_default_constructible(&mut self) -> BackendResult {
4673        let tab = back::INDENT;
4674        writeln!(self.out, "struct DefaultConstructible {{")?;
4675        writeln!(self.out, "{tab}template<typename T>")?;
4676        writeln!(self.out, "{tab}operator T() && {{")?;
4677        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4678        writeln!(self.out, "{tab}}}")?;
4679        writeln!(self.out, "}};")?;
4680        Ok(())
4681    }
4682
4683    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4684        let mut generated_argument_buffer_wrapper = false;
4685        let mut generated_external_texture_wrapper = false;
4686        for (handle, ty) in module.types.iter() {
4687            match ty.inner {
4688                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4689                    writeln!(self.out, "template <typename T>")?;
4690                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4691                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4692                    writeln!(self.out, "}};")?;
4693                    generated_argument_buffer_wrapper = true;
4694                }
4695                crate::TypeInner::Image {
4696                    class: crate::ImageClass::External,
4697                    ..
4698                } if !generated_external_texture_wrapper => {
4699                    let params_ty_name = &self.names
4700                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4701                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4702                    writeln!(
4703                        self.out,
4704                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4705                        back::INDENT
4706                    )?;
4707                    writeln!(
4708                        self.out,
4709                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4710                        back::INDENT
4711                    )?;
4712                    writeln!(
4713                        self.out,
4714                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4715                        back::INDENT
4716                    )?;
4717                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4718                    writeln!(self.out, "}};")?;
4719                    generated_external_texture_wrapper = true;
4720                }
4721                _ => {}
4722            }
4723
4724            if !ty.needs_alias() {
4725                continue;
4726            }
4727            let name = &self.names[&NameKey::Type(handle)];
4728            match ty.inner {
4729                // Naga IR can pass around arrays by value, but Metal, following
4730                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4731                // on expressions of array type, so assigning the array by value
4732                // isn't possible. However, Metal *does* assign structs by
4733                // value. So in our Metal output, we wrap all array types in
4734                // synthetic struct types:
4735                //
4736                //     struct type1 {
4737                //         float inner[10]
4738                //     };
4739                //
4740                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4741                // any expression that actually wants access to the array.
4742                crate::TypeInner::Array {
4743                    base,
4744                    size,
4745                    stride: _,
4746                } => {
4747                    let base_name = TypeContext {
4748                        handle: base,
4749                        gctx: module.to_ctx(),
4750                        names: &self.names,
4751                        access: crate::StorageAccess::empty(),
4752                        first_time: false,
4753                    };
4754
4755                    match size.resolve(module.to_ctx())? {
4756                        proc::IndexableLength::Known(size) => {
4757                            writeln!(self.out, "struct {name} {{")?;
4758                            writeln!(
4759                                self.out,
4760                                "{}{} {}[{}];",
4761                                back::INDENT,
4762                                base_name,
4763                                WRAPPED_ARRAY_FIELD,
4764                                size
4765                            )?;
4766                            writeln!(self.out, "}};")?;
4767                        }
4768                        proc::IndexableLength::Dynamic => {
4769                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4770                        }
4771                    }
4772                }
4773                crate::TypeInner::Struct {
4774                    ref members, span, ..
4775                } => {
4776                    writeln!(self.out, "struct {name} {{")?;
4777                    let mut last_offset = 0;
4778                    for (index, member) in members.iter().enumerate() {
4779                        if member.offset > last_offset {
4780                            self.struct_member_pads.insert((handle, index as u32));
4781                            let pad = member.offset - last_offset;
4782                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4783                        }
4784                        let ty_inner = &module.types[member.ty].inner;
4785                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4786
4787                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4788
4789                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4790                        match should_pack_struct_member(members, span, index, module) {
4791                            Some(scalar) => {
4792                                writeln!(
4793                                    self.out,
4794                                    "{}{}::packed_{}3 {};",
4795                                    back::INDENT,
4796                                    NAMESPACE,
4797                                    scalar.to_msl_name(),
4798                                    member_name
4799                                )?;
4800                            }
4801                            None => {
4802                                let base_name = TypeContext {
4803                                    handle: member.ty,
4804                                    gctx: module.to_ctx(),
4805                                    names: &self.names,
4806                                    access: crate::StorageAccess::empty(),
4807                                    first_time: false,
4808                                };
4809                                writeln!(
4810                                    self.out,
4811                                    "{}{} {};",
4812                                    back::INDENT,
4813                                    base_name,
4814                                    member_name
4815                                )?;
4816
4817                                // for 3-component vectors, add one component
4818                                if let crate::TypeInner::Vector {
4819                                    size: crate::VectorSize::Tri,
4820                                    scalar,
4821                                } = *ty_inner
4822                                {
4823                                    last_offset += scalar.width as u32;
4824                                }
4825                            }
4826                        }
4827                    }
4828                    if last_offset < span {
4829                        let pad = span - last_offset;
4830                        writeln!(
4831                            self.out,
4832                            "{}char _pad{}[{}];",
4833                            back::INDENT,
4834                            members.len(),
4835                            pad
4836                        )?;
4837                    }
4838                    writeln!(self.out, "}};")?;
4839                }
4840                _ => {
4841                    let ty_name = TypeContext {
4842                        handle,
4843                        gctx: module.to_ctx(),
4844                        names: &self.names,
4845                        access: crate::StorageAccess::empty(),
4846                        first_time: true,
4847                    };
4848                    writeln!(self.out, "typedef {ty_name} {name};")?;
4849                }
4850            }
4851        }
4852
4853        // Write functions to create special types.
4854        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4855            match type_key {
4856                &crate::PredeclaredType::ModfResult { size, scalar }
4857                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4858                    let arg_type_name_owner;
4859                    let arg_type_name = if let Some(size) = size {
4860                        arg_type_name_owner = format!(
4861                            "{NAMESPACE}::{}{}",
4862                            if scalar.width == 8 { "double" } else { "float" },
4863                            size as u8
4864                        );
4865                        &arg_type_name_owner
4866                    } else if scalar.width == 8 {
4867                        "double"
4868                    } else {
4869                        "float"
4870                    };
4871
4872                    let other_type_name_owner;
4873                    let (defined_func_name, called_func_name, other_type_name) =
4874                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4875                            (MODF_FUNCTION, "modf", arg_type_name)
4876                        } else {
4877                            let other_type_name = if let Some(size) = size {
4878                                other_type_name_owner = format!("int{}", size as u8);
4879                                &other_type_name_owner
4880                            } else {
4881                                "int"
4882                            };
4883                            (FREXP_FUNCTION, "frexp", other_type_name)
4884                        };
4885
4886                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4887
4888                    writeln!(self.out)?;
4889                    writeln!(
4890                        self.out,
4891                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4892    {other_type_name} other;
4893    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4894    return {struct_name}{{ fract, other }};
4895}}"
4896                    )?;
4897                }
4898                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4899                    let arg_type_name = scalar.to_msl_name();
4900                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4901                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4902                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4903
4904                    writeln!(self.out)?;
4905
4906                    for address_space_name in ["device", "threadgroup"] {
4907                        writeln!(
4908                            self.out,
4909                            "\
4910template <typename A>
4911{struct_name} {defined_func_name}(
4912    {address_space_name} A *atomic_ptr,
4913    {arg_type_name} cmp,
4914    {arg_type_name} v
4915) {{
4916    bool swapped = {NAMESPACE}::{called_func_name}(
4917        atomic_ptr, &cmp, v,
4918        metal::memory_order_relaxed, metal::memory_order_relaxed
4919    );
4920    return {struct_name}{{cmp, swapped}};
4921}}"
4922                        )?;
4923                    }
4924                }
4925            }
4926        }
4927
4928        Ok(())
4929    }
4930
4931    /// Writes all named constants
4932    fn write_global_constants(
4933        &mut self,
4934        module: &crate::Module,
4935        mod_info: &valid::ModuleInfo,
4936    ) -> BackendResult {
4937        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4938
4939        for (handle, constant) in constants {
4940            let ty_name = TypeContext {
4941                handle: constant.ty,
4942                gctx: module.to_ctx(),
4943                names: &self.names,
4944                access: crate::StorageAccess::empty(),
4945                first_time: false,
4946            };
4947            let name = &self.names[&NameKey::Constant(handle)];
4948            write!(self.out, "constant {ty_name} {name} = ")?;
4949            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4950            writeln!(self.out, ";")?;
4951        }
4952
4953        Ok(())
4954    }
4955
4956    fn put_inline_sampler_properties(
4957        &mut self,
4958        level: back::Level,
4959        sampler: &sm::InlineSampler,
4960    ) -> BackendResult {
4961        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4962            writeln!(
4963                self.out,
4964                "{}{}::{}_address::{},",
4965                level,
4966                NAMESPACE,
4967                letter,
4968                address.as_str(),
4969            )?;
4970        }
4971        writeln!(
4972            self.out,
4973            "{}{}::mag_filter::{},",
4974            level,
4975            NAMESPACE,
4976            sampler.mag_filter.as_str(),
4977        )?;
4978        writeln!(
4979            self.out,
4980            "{}{}::min_filter::{},",
4981            level,
4982            NAMESPACE,
4983            sampler.min_filter.as_str(),
4984        )?;
4985        if let Some(filter) = sampler.mip_filter {
4986            writeln!(
4987                self.out,
4988                "{}{}::mip_filter::{},",
4989                level,
4990                NAMESPACE,
4991                filter.as_str(),
4992            )?;
4993        }
4994        // avoid setting it on platforms that don't support it
4995        if sampler.border_color != sm::BorderColor::TransparentBlack {
4996            writeln!(
4997                self.out,
4998                "{}{}::border_color::{},",
4999                level,
5000                NAMESPACE,
5001                sampler.border_color.as_str(),
5002            )?;
5003        }
5004        //TODO: I'm not able to feed this in a way that MSL likes:
5005        //>error: use of undeclared identifier 'lod_clamp'
5006        //>error: no member named 'max_anisotropy' in namespace 'metal'
5007        if false {
5008            if let Some(ref lod) = sampler.lod_clamp {
5009                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
5010            }
5011            if let Some(aniso) = sampler.max_anisotropy {
5012                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
5013            }
5014        }
5015        if sampler.compare_func != sm::CompareFunc::Never {
5016            writeln!(
5017                self.out,
5018                "{}{}::compare_func::{},",
5019                level,
5020                NAMESPACE,
5021                sampler.compare_func.as_str(),
5022            )?;
5023        }
5024        writeln!(
5025            self.out,
5026            "{}{}::coord::{}",
5027            level,
5028            NAMESPACE,
5029            sampler.coord.as_str()
5030        )?;
5031        Ok(())
5032    }
5033
5034    fn write_unpacking_function(
5035        &mut self,
5036        format: nt::VertexFormat,
5037    ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
5038        use crate::{Scalar, VectorSize};
5039        use nt::VertexFormat::*;
5040        match format {
5041            Uint8 => {
5042                let name = self.namer.call("unpackUint8");
5043                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
5044                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
5045                writeln!(self.out, "}}")?;
5046                Ok((name, 1, None, Scalar::U32))
5047            }
5048            Uint8x2 => {
5049                let name = self.namer.call("unpackUint8x2");
5050                writeln!(
5051                    self.out,
5052                    "metal::uint2 {name}(metal::uchar b0, \
5053                                         metal::uchar b1) {{"
5054                )?;
5055                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
5056                writeln!(self.out, "}}")?;
5057                Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
5058            }
5059            Uint8x4 => {
5060                let name = self.namer.call("unpackUint8x4");
5061                writeln!(
5062                    self.out,
5063                    "metal::uint4 {name}(metal::uchar b0, \
5064                                         metal::uchar b1, \
5065                                         metal::uchar b2, \
5066                                         metal::uchar b3) {{"
5067                )?;
5068                writeln!(
5069                    self.out,
5070                    "{}return metal::uint4(b0, b1, b2, b3);",
5071                    back::INDENT
5072                )?;
5073                writeln!(self.out, "}}")?;
5074                Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
5075            }
5076            Sint8 => {
5077                let name = self.namer.call("unpackSint8");
5078                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
5079                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
5080                writeln!(self.out, "}}")?;
5081                Ok((name, 1, None, Scalar::I32))
5082            }
5083            Sint8x2 => {
5084                let name = self.namer.call("unpackSint8x2");
5085                writeln!(
5086                    self.out,
5087                    "metal::int2 {name}(metal::uchar b0, \
5088                                        metal::uchar b1) {{"
5089                )?;
5090                writeln!(
5091                    self.out,
5092                    "{}return metal::int2(as_type<char>(b0), \
5093                                          as_type<char>(b1));",
5094                    back::INDENT
5095                )?;
5096                writeln!(self.out, "}}")?;
5097                Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
5098            }
5099            Sint8x4 => {
5100                let name = self.namer.call("unpackSint8x4");
5101                writeln!(
5102                    self.out,
5103                    "metal::int4 {name}(metal::uchar b0, \
5104                                        metal::uchar b1, \
5105                                        metal::uchar b2, \
5106                                        metal::uchar b3) {{"
5107                )?;
5108                writeln!(
5109                    self.out,
5110                    "{}return metal::int4(as_type<char>(b0), \
5111                                          as_type<char>(b1), \
5112                                          as_type<char>(b2), \
5113                                          as_type<char>(b3));",
5114                    back::INDENT
5115                )?;
5116                writeln!(self.out, "}}")?;
5117                Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
5118            }
5119            Unorm8 => {
5120                let name = self.namer.call("unpackUnorm8");
5121                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5122                writeln!(
5123                    self.out,
5124                    "{}return float(float(b0) / 255.0f);",
5125                    back::INDENT
5126                )?;
5127                writeln!(self.out, "}}")?;
5128                Ok((name, 1, None, Scalar::F32))
5129            }
5130            Unorm8x2 => {
5131                let name = self.namer.call("unpackUnorm8x2");
5132                writeln!(
5133                    self.out,
5134                    "metal::float2 {name}(metal::uchar b0, \
5135                                          metal::uchar b1) {{"
5136                )?;
5137                writeln!(
5138                    self.out,
5139                    "{}return metal::float2(float(b0) / 255.0f, \
5140                                            float(b1) / 255.0f);",
5141                    back::INDENT
5142                )?;
5143                writeln!(self.out, "}}")?;
5144                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5145            }
5146            Unorm8x4 => {
5147                let name = self.namer.call("unpackUnorm8x4");
5148                writeln!(
5149                    self.out,
5150                    "metal::float4 {name}(metal::uchar b0, \
5151                                          metal::uchar b1, \
5152                                          metal::uchar b2, \
5153                                          metal::uchar b3) {{"
5154                )?;
5155                writeln!(
5156                    self.out,
5157                    "{}return metal::float4(float(b0) / 255.0f, \
5158                                            float(b1) / 255.0f, \
5159                                            float(b2) / 255.0f, \
5160                                            float(b3) / 255.0f);",
5161                    back::INDENT
5162                )?;
5163                writeln!(self.out, "}}")?;
5164                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5165            }
5166            Snorm8 => {
5167                let name = self.namer.call("unpackSnorm8");
5168                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5169                writeln!(
5170                    self.out,
5171                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
5172                    back::INDENT
5173                )?;
5174                writeln!(self.out, "}}")?;
5175                Ok((name, 1, None, Scalar::F32))
5176            }
5177            Snorm8x2 => {
5178                let name = self.namer.call("unpackSnorm8x2");
5179                writeln!(
5180                    self.out,
5181                    "metal::float2 {name}(metal::uchar b0, \
5182                                          metal::uchar b1) {{"
5183                )?;
5184                writeln!(
5185                    self.out,
5186                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5187                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5188                    back::INDENT
5189                )?;
5190                writeln!(self.out, "}}")?;
5191                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5192            }
5193            Snorm8x4 => {
5194                let name = self.namer.call("unpackSnorm8x4");
5195                writeln!(
5196                    self.out,
5197                    "metal::float4 {name}(metal::uchar b0, \
5198                                          metal::uchar b1, \
5199                                          metal::uchar b2, \
5200                                          metal::uchar b3) {{"
5201                )?;
5202                writeln!(
5203                    self.out,
5204                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5205                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5206                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5207                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5208                    back::INDENT
5209                )?;
5210                writeln!(self.out, "}}")?;
5211                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5212            }
5213            Uint16 => {
5214                let name = self.namer.call("unpackUint16");
5215                writeln!(
5216                    self.out,
5217                    "metal::uint {name}(metal::uint b0, \
5218                                        metal::uint b1) {{"
5219                )?;
5220                writeln!(
5221                    self.out,
5222                    "{}return metal::uint(b1 << 8 | b0);",
5223                    back::INDENT
5224                )?;
5225                writeln!(self.out, "}}")?;
5226                Ok((name, 2, None, Scalar::U32))
5227            }
5228            Uint16x2 => {
5229                let name = self.namer.call("unpackUint16x2");
5230                writeln!(
5231                    self.out,
5232                    "metal::uint2 {name}(metal::uint b0, \
5233                                         metal::uint b1, \
5234                                         metal::uint b2, \
5235                                         metal::uint b3) {{"
5236                )?;
5237                writeln!(
5238                    self.out,
5239                    "{}return metal::uint2(b1 << 8 | b0, \
5240                                           b3 << 8 | b2);",
5241                    back::INDENT
5242                )?;
5243                writeln!(self.out, "}}")?;
5244                Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5245            }
5246            Uint16x4 => {
5247                let name = self.namer.call("unpackUint16x4");
5248                writeln!(
5249                    self.out,
5250                    "metal::uint4 {name}(metal::uint b0, \
5251                                         metal::uint b1, \
5252                                         metal::uint b2, \
5253                                         metal::uint b3, \
5254                                         metal::uint b4, \
5255                                         metal::uint b5, \
5256                                         metal::uint b6, \
5257                                         metal::uint b7) {{"
5258                )?;
5259                writeln!(
5260                    self.out,
5261                    "{}return metal::uint4(b1 << 8 | b0, \
5262                                           b3 << 8 | b2, \
5263                                           b5 << 8 | b4, \
5264                                           b7 << 8 | b6);",
5265                    back::INDENT
5266                )?;
5267                writeln!(self.out, "}}")?;
5268                Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5269            }
5270            Sint16 => {
5271                let name = self.namer.call("unpackSint16");
5272                writeln!(
5273                    self.out,
5274                    "int {name}(metal::ushort b0, \
5275                                metal::ushort b1) {{"
5276                )?;
5277                writeln!(
5278                    self.out,
5279                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5280                    back::INDENT
5281                )?;
5282                writeln!(self.out, "}}")?;
5283                Ok((name, 2, None, Scalar::I32))
5284            }
5285            Sint16x2 => {
5286                let name = self.namer.call("unpackSint16x2");
5287                writeln!(
5288                    self.out,
5289                    "metal::int2 {name}(metal::ushort b0, \
5290                                        metal::ushort b1, \
5291                                        metal::ushort b2, \
5292                                        metal::ushort b3) {{"
5293                )?;
5294                writeln!(
5295                    self.out,
5296                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5297                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5298                    back::INDENT
5299                )?;
5300                writeln!(self.out, "}}")?;
5301                Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5302            }
5303            Sint16x4 => {
5304                let name = self.namer.call("unpackSint16x4");
5305                writeln!(
5306                    self.out,
5307                    "metal::int4 {name}(metal::ushort b0, \
5308                                        metal::ushort b1, \
5309                                        metal::ushort b2, \
5310                                        metal::ushort b3, \
5311                                        metal::ushort b4, \
5312                                        metal::ushort b5, \
5313                                        metal::ushort b6, \
5314                                        metal::ushort b7) {{"
5315                )?;
5316                writeln!(
5317                    self.out,
5318                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5319                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5320                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5321                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5322                    back::INDENT
5323                )?;
5324                writeln!(self.out, "}}")?;
5325                Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5326            }
5327            Unorm16 => {
5328                let name = self.namer.call("unpackUnorm16");
5329                writeln!(
5330                    self.out,
5331                    "float {name}(metal::ushort b0, \
5332                                  metal::ushort b1) {{"
5333                )?;
5334                writeln!(
5335                    self.out,
5336                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5337                    back::INDENT
5338                )?;
5339                writeln!(self.out, "}}")?;
5340                Ok((name, 2, None, Scalar::F32))
5341            }
5342            Unorm16x2 => {
5343                let name = self.namer.call("unpackUnorm16x2");
5344                writeln!(
5345                    self.out,
5346                    "metal::float2 {name}(metal::ushort b0, \
5347                                          metal::ushort b1, \
5348                                          metal::ushort b2, \
5349                                          metal::ushort b3) {{"
5350                )?;
5351                writeln!(
5352                    self.out,
5353                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5354                                            float(b3 << 8 | b2) / 65535.0f);",
5355                    back::INDENT
5356                )?;
5357                writeln!(self.out, "}}")?;
5358                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5359            }
5360            Unorm16x4 => {
5361                let name = self.namer.call("unpackUnorm16x4");
5362                writeln!(
5363                    self.out,
5364                    "metal::float4 {name}(metal::ushort b0, \
5365                                          metal::ushort b1, \
5366                                          metal::ushort b2, \
5367                                          metal::ushort b3, \
5368                                          metal::ushort b4, \
5369                                          metal::ushort b5, \
5370                                          metal::ushort b6, \
5371                                          metal::ushort b7) {{"
5372                )?;
5373                writeln!(
5374                    self.out,
5375                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5376                                            float(b3 << 8 | b2) / 65535.0f, \
5377                                            float(b5 << 8 | b4) / 65535.0f, \
5378                                            float(b7 << 8 | b6) / 65535.0f);",
5379                    back::INDENT
5380                )?;
5381                writeln!(self.out, "}}")?;
5382                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5383            }
5384            Snorm16 => {
5385                let name = self.namer.call("unpackSnorm16");
5386                writeln!(
5387                    self.out,
5388                    "float {name}(metal::ushort b0, \
5389                                  metal::ushort b1) {{"
5390                )?;
5391                writeln!(
5392                    self.out,
5393                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5394                    back::INDENT
5395                )?;
5396                writeln!(self.out, "}}")?;
5397                Ok((name, 2, None, Scalar::F32))
5398            }
5399            Snorm16x2 => {
5400                let name = self.namer.call("unpackSnorm16x2");
5401                writeln!(
5402                    self.out,
5403                    "metal::float2 {name}(uint b0, \
5404                                          uint b1, \
5405                                          uint b2, \
5406                                          uint b3) {{"
5407                )?;
5408                writeln!(
5409                    self.out,
5410                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5411                    back::INDENT
5412                )?;
5413                writeln!(self.out, "}}")?;
5414                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5415            }
5416            Snorm16x4 => {
5417                let name = self.namer.call("unpackSnorm16x4");
5418                writeln!(
5419                    self.out,
5420                    "metal::float4 {name}(uint b0, \
5421                                          uint b1, \
5422                                          uint b2, \
5423                                          uint b3, \
5424                                          uint b4, \
5425                                          uint b5, \
5426                                          uint b6, \
5427                                          uint b7) {{"
5428                )?;
5429                writeln!(
5430                    self.out,
5431                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5432                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5433                    back::INDENT
5434                )?;
5435                writeln!(self.out, "}}")?;
5436                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5437            }
5438            Float16 => {
5439                let name = self.namer.call("unpackFloat16");
5440                writeln!(
5441                    self.out,
5442                    "float {name}(metal::ushort b0, \
5443                                  metal::ushort b1) {{"
5444                )?;
5445                writeln!(
5446                    self.out,
5447                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5448                    back::INDENT
5449                )?;
5450                writeln!(self.out, "}}")?;
5451                Ok((name, 2, None, Scalar::F32))
5452            }
5453            Float16x2 => {
5454                let name = self.namer.call("unpackFloat16x2");
5455                writeln!(
5456                    self.out,
5457                    "metal::float2 {name}(metal::ushort b0, \
5458                                          metal::ushort b1, \
5459                                          metal::ushort b2, \
5460                                          metal::ushort b3) {{"
5461                )?;
5462                writeln!(
5463                    self.out,
5464                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5465                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5466                    back::INDENT
5467                )?;
5468                writeln!(self.out, "}}")?;
5469                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5470            }
5471            Float16x4 => {
5472                let name = self.namer.call("unpackFloat16x4");
5473                writeln!(
5474                    self.out,
5475                    "metal::float4 {name}(metal::ushort b0, \
5476                                        metal::ushort b1, \
5477                                        metal::ushort b2, \
5478                                        metal::ushort b3, \
5479                                        metal::ushort b4, \
5480                                        metal::ushort b5, \
5481                                        metal::ushort b6, \
5482                                        metal::ushort b7) {{"
5483                )?;
5484                writeln!(
5485                    self.out,
5486                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5487                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5488                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5489                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5490                    back::INDENT
5491                )?;
5492                writeln!(self.out, "}}")?;
5493                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5494            }
5495            Float32 => {
5496                let name = self.namer.call("unpackFloat32");
5497                writeln!(
5498                    self.out,
5499                    "float {name}(uint b0, \
5500                                  uint b1, \
5501                                  uint b2, \
5502                                  uint b3) {{"
5503                )?;
5504                writeln!(
5505                    self.out,
5506                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5507                    back::INDENT
5508                )?;
5509                writeln!(self.out, "}}")?;
5510                Ok((name, 4, None, Scalar::F32))
5511            }
5512            Float32x2 => {
5513                let name = self.namer.call("unpackFloat32x2");
5514                writeln!(
5515                    self.out,
5516                    "metal::float2 {name}(uint b0, \
5517                                          uint b1, \
5518                                          uint b2, \
5519                                          uint b3, \
5520                                          uint b4, \
5521                                          uint b5, \
5522                                          uint b6, \
5523                                          uint b7) {{"
5524                )?;
5525                writeln!(
5526                    self.out,
5527                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5528                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5529                    back::INDENT
5530                )?;
5531                writeln!(self.out, "}}")?;
5532                Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5533            }
5534            Float32x3 => {
5535                let name = self.namer.call("unpackFloat32x3");
5536                writeln!(
5537                    self.out,
5538                    "metal::float3 {name}(uint b0, \
5539                                          uint b1, \
5540                                          uint b2, \
5541                                          uint b3, \
5542                                          uint b4, \
5543                                          uint b5, \
5544                                          uint b6, \
5545                                          uint b7, \
5546                                          uint b8, \
5547                                          uint b9, \
5548                                          uint b10, \
5549                                          uint b11) {{"
5550                )?;
5551                writeln!(
5552                    self.out,
5553                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5554                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5555                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5556                    back::INDENT
5557                )?;
5558                writeln!(self.out, "}}")?;
5559                Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5560            }
5561            Float32x4 => {
5562                let name = self.namer.call("unpackFloat32x4");
5563                writeln!(
5564                    self.out,
5565                    "metal::float4 {name}(uint b0, \
5566                                          uint b1, \
5567                                          uint b2, \
5568                                          uint b3, \
5569                                          uint b4, \
5570                                          uint b5, \
5571                                          uint b6, \
5572                                          uint b7, \
5573                                          uint b8, \
5574                                          uint b9, \
5575                                          uint b10, \
5576                                          uint b11, \
5577                                          uint b12, \
5578                                          uint b13, \
5579                                          uint b14, \
5580                                          uint b15) {{"
5581                )?;
5582                writeln!(
5583                    self.out,
5584                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5585                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5586                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5587                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5588                    back::INDENT
5589                )?;
5590                writeln!(self.out, "}}")?;
5591                Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5592            }
5593            Uint32 => {
5594                let name = self.namer.call("unpackUint32");
5595                writeln!(
5596                    self.out,
5597                    "uint {name}(uint b0, \
5598                                 uint b1, \
5599                                 uint b2, \
5600                                 uint b3) {{"
5601                )?;
5602                writeln!(
5603                    self.out,
5604                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5605                    back::INDENT
5606                )?;
5607                writeln!(self.out, "}}")?;
5608                Ok((name, 4, None, Scalar::U32))
5609            }
5610            Uint32x2 => {
5611                let name = self.namer.call("unpackUint32x2");
5612                writeln!(
5613                    self.out,
5614                    "uint2 {name}(uint b0, \
5615                                  uint b1, \
5616                                  uint b2, \
5617                                  uint b3, \
5618                                  uint b4, \
5619                                  uint b5, \
5620                                  uint b6, \
5621                                  uint b7) {{"
5622                )?;
5623                writeln!(
5624                    self.out,
5625                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5626                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5627                    back::INDENT
5628                )?;
5629                writeln!(self.out, "}}")?;
5630                Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5631            }
5632            Uint32x3 => {
5633                let name = self.namer.call("unpackUint32x3");
5634                writeln!(
5635                    self.out,
5636                    "uint3 {name}(uint b0, \
5637                                  uint b1, \
5638                                  uint b2, \
5639                                  uint b3, \
5640                                  uint b4, \
5641                                  uint b5, \
5642                                  uint b6, \
5643                                  uint b7, \
5644                                  uint b8, \
5645                                  uint b9, \
5646                                  uint b10, \
5647                                  uint b11) {{"
5648                )?;
5649                writeln!(
5650                    self.out,
5651                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5652                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5653                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5654                    back::INDENT
5655                )?;
5656                writeln!(self.out, "}}")?;
5657                Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5658            }
5659            Uint32x4 => {
5660                let name = self.namer.call("unpackUint32x4");
5661                writeln!(
5662                    self.out,
5663                    "{NAMESPACE}::uint4 {name}(uint b0, \
5664                                  uint b1, \
5665                                  uint b2, \
5666                                  uint b3, \
5667                                  uint b4, \
5668                                  uint b5, \
5669                                  uint b6, \
5670                                  uint b7, \
5671                                  uint b8, \
5672                                  uint b9, \
5673                                  uint b10, \
5674                                  uint b11, \
5675                                  uint b12, \
5676                                  uint b13, \
5677                                  uint b14, \
5678                                  uint b15) {{"
5679                )?;
5680                writeln!(
5681                    self.out,
5682                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5683                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5684                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5685                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5686                    back::INDENT
5687                )?;
5688                writeln!(self.out, "}}")?;
5689                Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5690            }
5691            Sint32 => {
5692                let name = self.namer.call("unpackSint32");
5693                writeln!(
5694                    self.out,
5695                    "int {name}(uint b0, \
5696                                uint b1, \
5697                                uint b2, \
5698                                uint b3) {{"
5699                )?;
5700                writeln!(
5701                    self.out,
5702                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5703                    back::INDENT
5704                )?;
5705                writeln!(self.out, "}}")?;
5706                Ok((name, 4, None, Scalar::I32))
5707            }
5708            Sint32x2 => {
5709                let name = self.namer.call("unpackSint32x2");
5710                writeln!(
5711                    self.out,
5712                    "metal::int2 {name}(uint b0, \
5713                                        uint b1, \
5714                                        uint b2, \
5715                                        uint b3, \
5716                                        uint b4, \
5717                                        uint b5, \
5718                                        uint b6, \
5719                                        uint b7) {{"
5720                )?;
5721                writeln!(
5722                    self.out,
5723                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5724                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5725                    back::INDENT
5726                )?;
5727                writeln!(self.out, "}}")?;
5728                Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5729            }
5730            Sint32x3 => {
5731                let name = self.namer.call("unpackSint32x3");
5732                writeln!(
5733                    self.out,
5734                    "metal::int3 {name}(uint b0, \
5735                                        uint b1, \
5736                                        uint b2, \
5737                                        uint b3, \
5738                                        uint b4, \
5739                                        uint b5, \
5740                                        uint b6, \
5741                                        uint b7, \
5742                                        uint b8, \
5743                                        uint b9, \
5744                                        uint b10, \
5745                                        uint b11) {{"
5746                )?;
5747                writeln!(
5748                    self.out,
5749                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5750                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5751                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5752                    back::INDENT
5753                )?;
5754                writeln!(self.out, "}}")?;
5755                Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5756            }
5757            Sint32x4 => {
5758                let name = self.namer.call("unpackSint32x4");
5759                writeln!(
5760                    self.out,
5761                    "metal::int4 {name}(uint b0, \
5762                                        uint b1, \
5763                                        uint b2, \
5764                                        uint b3, \
5765                                        uint b4, \
5766                                        uint b5, \
5767                                        uint b6, \
5768                                        uint b7, \
5769                                        uint b8, \
5770                                        uint b9, \
5771                                        uint b10, \
5772                                        uint b11, \
5773                                        uint b12, \
5774                                        uint b13, \
5775                                        uint b14, \
5776                                        uint b15) {{"
5777                )?;
5778                writeln!(
5779                    self.out,
5780                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5781                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5782                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5783                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5784                    back::INDENT
5785                )?;
5786                writeln!(self.out, "}}")?;
5787                Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5788            }
5789            Unorm10_10_10_2 => {
5790                let name = self.namer.call("unpackUnorm10_10_10_2");
5791                writeln!(
5792                    self.out,
5793                    "metal::float4 {name}(uint b0, \
5794                                          uint b1, \
5795                                          uint b2, \
5796                                          uint b3) {{"
5797                )?;
5798                writeln!(
5799                    self.out,
5800                    // The following is correct for RGBA packing, but our format seems to
5801                    // match ABGR, which can be fed into the Metal builtin function
5802                    // unpack_unorm10a2_to_float.
5803                    /*
5804                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5805                       uint r = (v & 0xFFC00000) >> 22; \
5806                       uint g = (v & 0x003FF000) >> 12; \
5807                       uint b = (v & 0x00000FFC) >> 2; \
5808                       uint a = (v & 0x00000003); \
5809                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5810                    */
5811                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5812                    back::INDENT
5813                )?;
5814                writeln!(self.out, "}}")?;
5815                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5816            }
5817            Unorm8x4Bgra => {
5818                let name = self.namer.call("unpackUnorm8x4Bgra");
5819                writeln!(
5820                    self.out,
5821                    "metal::float4 {name}(metal::uchar b0, \
5822                                          metal::uchar b1, \
5823                                          metal::uchar b2, \
5824                                          metal::uchar b3) {{"
5825                )?;
5826                writeln!(
5827                    self.out,
5828                    "{}return metal::float4(float(b2) / 255.0f, \
5829                                            float(b1) / 255.0f, \
5830                                            float(b0) / 255.0f, \
5831                                            float(b3) / 255.0f);",
5832                    back::INDENT
5833                )?;
5834                writeln!(self.out, "}}")?;
5835                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5836            }
5837            Float64 | Float64x2 | Float64x3 | Float64x4 => unreachable!(),
5838        }
5839    }
5840
5841    fn write_wrapped_unary_op(
5842        &mut self,
5843        module: &crate::Module,
5844        func_ctx: &back::FunctionCtx,
5845        op: crate::UnaryOperator,
5846        operand: Handle<crate::Expression>,
5847    ) -> BackendResult {
5848        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5849        match op {
5850            // Negating the TYPE_MIN of a two's complement signed integer
5851            // type causes overflow, which is undefined behaviour in MSL. To
5852            // avoid this we bitcast the value to unsigned and negate it,
5853            // then bitcast back to signed.
5854            // This adheres to the WGSL spec in that the negative of the
5855            // type's minimum value should equal to the minimum value.
5856            crate::UnaryOperator::Negate
5857                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5858            {
5859                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5860                    return Ok(());
5861                };
5862                let wrapped = WrappedFunction::UnaryOp {
5863                    op,
5864                    ty: (vector_size, scalar),
5865                };
5866                if !self.wrapped_functions.insert(wrapped) {
5867                    return Ok(());
5868                }
5869
5870                let unsigned_scalar = crate::Scalar {
5871                    kind: crate::ScalarKind::Uint,
5872                    ..scalar
5873                };
5874                let mut type_name = String::new();
5875                let mut unsigned_type_name = String::new();
5876                match vector_size {
5877                    None => {
5878                        put_numeric_type(&mut type_name, scalar, &[])?;
5879                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5880                    }
5881                    Some(size) => {
5882                        put_numeric_type(&mut type_name, scalar, &[size])?;
5883                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5884                    }
5885                };
5886
5887                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5888                let level = back::Level(1);
5889                // For sub-32-bit types, C++ integer promotion widens
5890                // `-as_type<ushort>(val)` to `int`, so we need static_cast
5891                // to truncate back before the outer as_type bitcast.
5892                if scalar.width < 4 {
5893                    writeln!(
5894                        self.out,
5895                        "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5896                    )?;
5897                } else {
5898                    writeln!(
5899                        self.out,
5900                        "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5901                    )?;
5902                }
5903                writeln!(self.out, "}}")?;
5904                writeln!(self.out)?;
5905            }
5906            _ => {}
5907        }
5908        Ok(())
5909    }
5910
5911    fn write_wrapped_binary_op(
5912        &mut self,
5913        module: &crate::Module,
5914        func_ctx: &back::FunctionCtx,
5915        expr: Handle<crate::Expression>,
5916        op: crate::BinaryOperator,
5917        left: Handle<crate::Expression>,
5918        right: Handle<crate::Expression>,
5919    ) -> BackendResult {
5920        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5921        let left_ty = func_ctx.resolve_type(left, &module.types);
5922        let right_ty = func_ctx.resolve_type(right, &module.types);
5923        match (op, expr_ty.scalar_kind()) {
5924            // Signed integer division of TYPE_MIN / -1, or signed or
5925            // unsigned division by zero, gives an unspecified value in MSL.
5926            // We override the divisor to 1 in these cases.
5927            // This adheres to the WGSL spec in that:
5928            // * TYPE_MIN / -1 == TYPE_MIN
5929            // * x / 0 == x
5930            (
5931                crate::BinaryOperator::Divide,
5932                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5933            ) if self.emit_int_div_checks => {
5934                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5935                    return Ok(());
5936                };
5937                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5938                    return Ok(());
5939                };
5940                let wrapped = WrappedFunction::BinaryOp {
5941                    op,
5942                    left_ty: left_wrapped_ty,
5943                    right_ty: right_wrapped_ty,
5944                };
5945                if !self.wrapped_functions.insert(wrapped) {
5946                    return Ok(());
5947                }
5948
5949                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5950                    return Ok(());
5951                };
5952                let mut type_name = String::new();
5953                match vector_size {
5954                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5955                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5956                };
5957                writeln!(
5958                    self.out,
5959                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5960                )?;
5961                let level = back::Level(1);
5962                // Sub-32-bit types need typed literal wrappers (e.g. `short(1)`)
5963                // to avoid ambiguous metal::select overloads. For >= 32-bit,
5964                // bare literals like `1`, `-1`, `0` are unambiguous.
5965                let (lp, rp) = if scalar.width < 4 {
5966                    (format!("{type_name}("), ")".to_string())
5967                } else {
5968                    (String::new(), String::new())
5969                };
5970                match scalar.kind {
5971                    crate::ScalarKind::Sint => {
5972                        let min_val = match scalar.width {
5973                            2 => crate::Literal::I16(i16::MIN),
5974                            4 => crate::Literal::I32(i32::MIN),
5975                            8 => crate::Literal::I64(i64::MIN),
5976                            _ => {
5977                                return Err(Error::GenericValidation(format!(
5978                                    "Unexpected width for scalar {scalar:?}"
5979                                )));
5980                            }
5981                        };
5982                        write!(
5983                            self.out,
5984                            "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5985                        )?;
5986                        self.put_literal(min_val)?;
5987                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5988                    }
5989                    crate::ScalarKind::Uint => {
5990                        let suffix = if scalar.width < 4 { "" } else { "u" };
5991                        writeln!(
5992                            self.out,
5993                            "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5994                        )?
5995                    }
5996                    _ => unreachable!(),
5997                }
5998                writeln!(self.out, "}}")?;
5999                writeln!(self.out)?;
6000            }
6001            // Integer modulo where one or both operands are negative, or the
6002            // divisor is zero, is undefined behaviour in MSL. To avoid this
6003            // we use the following equation:
6004            //
6005            // dividend - (dividend / divisor) * divisor
6006            //
6007            // overriding the divisor to 1 if either it is 0, or it is -1
6008            // and the dividend is TYPE_MIN.
6009            //
6010            // This adheres to the WGSL spec in that:
6011            // * TYPE_MIN % -1 == 0
6012            // * x % 0 == 0
6013            (
6014                crate::BinaryOperator::Modulo,
6015                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
6016            ) if self.emit_int_div_checks => {
6017                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
6018                    return Ok(());
6019                };
6020                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
6021                else {
6022                    return Ok(());
6023                };
6024                let wrapped = WrappedFunction::BinaryOp {
6025                    op,
6026                    left_ty: left_wrapped_ty,
6027                    right_ty: (right_vector_size, right_scalar),
6028                };
6029                if !self.wrapped_functions.insert(wrapped) {
6030                    return Ok(());
6031                }
6032
6033                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
6034                    return Ok(());
6035                };
6036                let mut type_name = String::new();
6037                match vector_size {
6038                    None => put_numeric_type(&mut type_name, scalar, &[])?,
6039                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
6040                };
6041                let mut rhs_type_name = String::new();
6042                match right_vector_size {
6043                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
6044                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
6045                };
6046
6047                writeln!(
6048                    self.out,
6049                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
6050                )?;
6051                let level = back::Level(1);
6052                let (lp, rp) = if scalar.width < 4 {
6053                    (format!("{type_name}("), ")".to_string())
6054                } else {
6055                    (String::new(), String::new())
6056                };
6057                match scalar.kind {
6058                    crate::ScalarKind::Sint => {
6059                        let min_val = match scalar.width {
6060                            2 => crate::Literal::I16(i16::MIN),
6061                            4 => crate::Literal::I32(i32::MIN),
6062                            8 => crate::Literal::I64(i64::MIN),
6063                            _ => {
6064                                return Err(Error::GenericValidation(format!(
6065                                    "Unexpected width for scalar {scalar:?}"
6066                                )));
6067                            }
6068                        };
6069                        write!(
6070                            self.out,
6071                            "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
6072                        )?;
6073                        self.put_literal(min_val)?;
6074                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
6075                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
6076                    }
6077                    crate::ScalarKind::Uint => {
6078                        let suffix = if scalar.width < 4 { "" } else { "u" };
6079                        writeln!(
6080                            self.out,
6081                            "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
6082                        )?
6083                    }
6084                    _ => unreachable!(),
6085                }
6086                writeln!(self.out, "}}")?;
6087                writeln!(self.out)?;
6088            }
6089            _ => {}
6090        }
6091        Ok(())
6092    }
6093
6094    /// Build the mangled helper name for integer vector dot products.
6095    ///
6096    /// `scalar` must be a concrete integer scalar type.
6097    ///
6098    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
6099    fn get_dot_wrapper_function_helper_name(
6100        &self,
6101        scalar: crate::Scalar,
6102        size: crate::VectorSize,
6103    ) -> String {
6104        // Check for consistency with [`super::keywords::RESERVED_SET`]
6105        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
6106
6107        let type_name = scalar.to_msl_name();
6108        let size_suffix = common::vector_size_str(size);
6109        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
6110    }
6111
6112    #[allow(clippy::too_many_arguments)]
6113    fn write_wrapped_math_function(
6114        &mut self,
6115        module: &crate::Module,
6116        func_ctx: &back::FunctionCtx,
6117        fun: crate::MathFunction,
6118        arg: Handle<crate::Expression>,
6119        _arg1: Option<Handle<crate::Expression>>,
6120        _arg2: Option<Handle<crate::Expression>>,
6121        _arg3: Option<Handle<crate::Expression>>,
6122    ) -> BackendResult {
6123        let arg_ty = func_ctx.resolve_type(arg, &module.types);
6124        match fun {
6125            // Taking the absolute value of the TYPE_MIN of a two's
6126            // complement signed integer type causes overflow, which is
6127            // undefined behaviour in MSL. To avoid this, when the value is
6128            // negative we bitcast the value to unsigned and negate it, then
6129            // bitcast back to signed.
6130            // This adheres to the WGSL spec in that the absolute of the
6131            // type's minimum value should equal to the minimum value.
6132            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
6133                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
6134                    return Ok(());
6135                };
6136                let wrapped = WrappedFunction::Math {
6137                    fun,
6138                    arg_ty: (vector_size, scalar),
6139                };
6140                if !self.wrapped_functions.insert(wrapped) {
6141                    return Ok(());
6142                }
6143
6144                let unsigned_scalar = crate::Scalar {
6145                    kind: crate::ScalarKind::Uint,
6146                    ..scalar
6147                };
6148                let mut type_name = String::new();
6149                let mut unsigned_type_name = String::new();
6150                match vector_size {
6151                    None => {
6152                        put_numeric_type(&mut type_name, scalar, &[])?;
6153                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
6154                    }
6155                    Some(size) => {
6156                        put_numeric_type(&mut type_name, scalar, &[size])?;
6157                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
6158                    }
6159                };
6160
6161                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
6162                let level = back::Level(1);
6163                let zero = if scalar.width < 4 {
6164                    format!("{type_name}(0)")
6165                } else {
6166                    "0".to_string()
6167                };
6168                let neg_expr = if scalar.width < 4 {
6169                    format!(
6170                        "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
6171                    )
6172                } else {
6173                    format!("-as_type<{unsigned_type_name}>(val)")
6174                };
6175                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
6176                writeln!(self.out, "}}")?;
6177                writeln!(self.out)?;
6178            }
6179
6180            crate::MathFunction::Dot => match *arg_ty {
6181                crate::TypeInner::Vector { size, scalar }
6182                    if matches!(
6183                        scalar.kind,
6184                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
6185                    ) =>
6186                {
6187                    // De-duplicate per (fun, arg type) like other wrapped math functions
6188                    let wrapped = WrappedFunction::Math {
6189                        fun,
6190                        arg_ty: (Some(size), scalar),
6191                    };
6192                    if !self.wrapped_functions.insert(wrapped) {
6193                        return Ok(());
6194                    }
6195
6196                    let mut vec_ty = String::new();
6197                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
6198                    let mut ret_ty = String::new();
6199                    put_numeric_type(&mut ret_ty, scalar, &[])?;
6200
6201                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6202
6203                    // Emit function signature and body using put_dot_product for the expression
6204                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6205                    let level = back::Level(1);
6206                    write!(self.out, "{level}return ")?;
6207                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6208                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6209                        Ok(())
6210                    })?;
6211                    writeln!(self.out, ";")?;
6212                    writeln!(self.out, "}}")?;
6213                    writeln!(self.out)?;
6214                }
6215                _ => {}
6216            },
6217
6218            _ => {}
6219        }
6220        Ok(())
6221    }
6222
6223    fn write_wrapped_cast(
6224        &mut self,
6225        module: &crate::Module,
6226        func_ctx: &back::FunctionCtx,
6227        expr: Handle<crate::Expression>,
6228        kind: crate::ScalarKind,
6229        convert: Option<crate::Bytes>,
6230    ) -> BackendResult {
6231        // Avoid undefined behaviour when casting from a float to integer
6232        // when the value is out of range for the target type. Additionally
6233        // ensure we clamp to the correct value as per the WGSL spec.
6234        //
6235        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
6236        // * If X is exactly representable in the target type T, then the
6237        //   result is that value.
6238        // * Otherwise, the result is the value in T closest to
6239        //   truncate(X) and also exactly representable in the original
6240        //   floating point type.
6241        let src_ty = func_ctx.resolve_type(expr, &module.types);
6242        let Some(width) = convert else {
6243            return Ok(());
6244        };
6245        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6246            return Ok(());
6247        };
6248        let dst_scalar = crate::Scalar { kind, width };
6249        if src_scalar.kind != crate::ScalarKind::Float
6250            || (dst_scalar.kind != crate::ScalarKind::Sint
6251                && dst_scalar.kind != crate::ScalarKind::Uint)
6252        {
6253            return Ok(());
6254        }
6255        let wrapped = WrappedFunction::Cast {
6256            src_scalar,
6257            vector_size,
6258            dst_scalar,
6259        };
6260        if !self.wrapped_functions.insert(wrapped) {
6261            return Ok(());
6262        }
6263        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6264
6265        let mut src_type_name = String::new();
6266        match vector_size {
6267            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6268            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6269        };
6270        let mut dst_type_name = String::new();
6271        match vector_size {
6272            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6273            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6274        };
6275        let fun_name = match dst_scalar {
6276            crate::Scalar::I32 => F2I32_FUNCTION,
6277            crate::Scalar::U32 => F2U32_FUNCTION,
6278            crate::Scalar::I64 => F2I64_FUNCTION,
6279            crate::Scalar::U64 => F2U64_FUNCTION,
6280            _ => unreachable!(),
6281        };
6282
6283        writeln!(
6284            self.out,
6285            "{dst_type_name} {fun_name}({src_type_name} value) {{"
6286        )?;
6287        let level = back::Level(1);
6288        write!(
6289            self.out,
6290            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6291        )?;
6292        self.put_literal(min)?;
6293        write!(self.out, ", ")?;
6294        self.put_literal(max)?;
6295        writeln!(self.out, "));")?;
6296        writeln!(self.out, "}}")?;
6297        writeln!(self.out)?;
6298        Ok(())
6299    }
6300
6301    /// Helper function used by [`Self::write_wrapped_image_load`] and
6302    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
6303    /// conversion code for external textures. Expects the preceding code to
6304    /// declare the Y component as a `float` variable of name `y`, the UV
6305    /// components as a `float2` variable of name `uv`, and the external
6306    /// texture params as a variable of name `params`. The emitted code will
6307    /// return the result.
6308    fn write_convert_yuv_to_rgb_and_return(
6309        &mut self,
6310        level: back::Level,
6311        y: &str,
6312        uv: &str,
6313        params: &str,
6314    ) -> BackendResult {
6315        let l1 = level;
6316        let l2 = l1.next();
6317
6318        // Convert from YUV to non-linear RGB in the source color space.
6319        writeln!(
6320            self.out,
6321            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6322        )?;
6323
6324        // Apply the inverse of the source transfer function to convert to
6325        // linear RGB in the source color space.
6326        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6327        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6328        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6329        writeln!(
6330            self.out,
6331            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6332        )?;
6333
6334        // Multiply by the gamut conversion matrix to convert to linear RGB in
6335        // the destination color space.
6336        writeln!(
6337            self.out,
6338            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6339        )?;
6340
6341        // Finally, apply the dest transfer function to convert to non-linear
6342        // RGB in the destination color space, and return the result.
6343        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6344        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6345        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6346        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6347
6348        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6349        Ok(())
6350    }
6351
6352    #[allow(clippy::too_many_arguments)]
6353    fn write_wrapped_image_load(
6354        &mut self,
6355        module: &crate::Module,
6356        func_ctx: &back::FunctionCtx,
6357        image: Handle<crate::Expression>,
6358        _coordinate: Handle<crate::Expression>,
6359        _array_index: Option<Handle<crate::Expression>>,
6360        _sample: Option<Handle<crate::Expression>>,
6361        _level: Option<Handle<crate::Expression>>,
6362    ) -> BackendResult {
6363        // We currently only need to wrap image loads for external textures
6364        let class = match *func_ctx.resolve_type(image, &module.types) {
6365            crate::TypeInner::Image { class, .. } => class,
6366            _ => unreachable!(),
6367        };
6368        if class != crate::ImageClass::External {
6369            return Ok(());
6370        }
6371        let wrapped = WrappedFunction::ImageLoad { class };
6372        if !self.wrapped_functions.insert(wrapped) {
6373            return Ok(());
6374        }
6375
6376        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6377        let l1 = back::Level(1);
6378        let l2 = l1.next();
6379        let l3 = l2.next();
6380        writeln!(
6381            self.out,
6382            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6383        )?;
6384        // Clamp coords to provided size of external texture to prevent OOB
6385        // read. If params.size is zero then clamp to the actual size of the
6386        // texture.
6387        writeln!(
6388            self.out,
6389            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6390        )?;
6391        writeln!(
6392            self.out,
6393            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6394        )?;
6395
6396        // Apply load transformation
6397        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6398        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6399        // For single plane, simply read from plane0
6400        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6401        writeln!(self.out, "{l1}}} else {{")?;
6402
6403        // Chroma planes may be subsampled so we must scale the coords accordingly.
6404        writeln!(
6405            self.out,
6406            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6407        )?;
6408        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6409
6410        // For multi-plane, read the Y value from plane 0
6411        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6412
6413        writeln!(self.out, "{l2}float2 uv;")?;
6414        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6415        // For 2 planes, read UV from interleaved plane 1
6416        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6417        writeln!(self.out, "{l2}}} else {{")?;
6418        // For 3 planes, read U and V from planes 1 and 2 respectively
6419        writeln!(
6420            self.out,
6421            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6422        )?;
6423        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6424        writeln!(
6425            self.out,
6426            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6427        )?;
6428        writeln!(self.out, "{l2}}}")?;
6429
6430        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6431
6432        writeln!(self.out, "{l1}}}")?;
6433        writeln!(self.out, "}}")?;
6434        writeln!(self.out)?;
6435        Ok(())
6436    }
6437
6438    #[allow(clippy::too_many_arguments)]
6439    fn write_wrapped_image_sample(
6440        &mut self,
6441        module: &crate::Module,
6442        func_ctx: &back::FunctionCtx,
6443        image: Handle<crate::Expression>,
6444        _sampler: Handle<crate::Expression>,
6445        _gather: Option<crate::SwizzleComponent>,
6446        _coordinate: Handle<crate::Expression>,
6447        _array_index: Option<Handle<crate::Expression>>,
6448        _offset: Option<Handle<crate::Expression>>,
6449        _level: crate::SampleLevel,
6450        _depth_ref: Option<Handle<crate::Expression>>,
6451        clamp_to_edge: bool,
6452    ) -> BackendResult {
6453        // We currently only need to wrap textureSampleBaseClampToEdge, for
6454        // both sampled and external textures.
6455        if !clamp_to_edge {
6456            return Ok(());
6457        }
6458        let class = match *func_ctx.resolve_type(image, &module.types) {
6459            crate::TypeInner::Image { class, .. } => class,
6460            _ => unreachable!(),
6461        };
6462        let wrapped = WrappedFunction::ImageSample {
6463            class,
6464            clamp_to_edge: true,
6465        };
6466        if !self.wrapped_functions.insert(wrapped) {
6467            return Ok(());
6468        }
6469        match class {
6470            crate::ImageClass::External => {
6471                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6472                let l1 = back::Level(1);
6473                let l2 = l1.next();
6474                let l3 = l2.next();
6475                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6476                writeln!(
6477                    self.out,
6478                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6479                )?;
6480
6481                // Calculate the sample bounds. The purported size of the texture
6482                // (params.size) is irrelevant here as we are dealing with normalized
6483                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6484                // apply the sample transformation to that, also bearing in mind that it
6485                // may contain a flip on either axis. We calculate and adjust for the
6486                // half-texel separately for each plane as it depends on the actual
6487                // texture size which may vary between planes.
6488                writeln!(
6489                    self.out,
6490                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6491                )?;
6492                writeln!(
6493                    self.out,
6494                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6495                )?;
6496                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6497                writeln!(
6498                    self.out,
6499                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6500                )?;
6501                writeln!(
6502                    self.out,
6503                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6504                )?;
6505                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6506                // For single plane, simply sample from plane0
6507                writeln!(
6508                    self.out,
6509                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6510                )?;
6511                writeln!(self.out, "{l1}}} else {{")?;
6512                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6513                writeln!(
6514                    self.out,
6515                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6516                )?;
6517                writeln!(
6518                    self.out,
6519                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6520                )?;
6521
6522                // For multi-plane, sample the Y value from plane 0
6523                writeln!(
6524                    self.out,
6525                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6526                )?;
6527                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6528                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6529                // For 2 planes, sample UV from interleaved plane 1
6530                writeln!(
6531                    self.out,
6532                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6533                )?;
6534                writeln!(self.out, "{l2}}} else {{")?;
6535                // For 3 planes, sample U and V from planes 1 and 2 respectively
6536                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6537                writeln!(
6538                    self.out,
6539                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6540                )?;
6541                writeln!(
6542                    self.out,
6543                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6544                )?;
6545                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6546                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6547                writeln!(self.out, "{l2}}}")?;
6548
6549                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6550
6551                writeln!(self.out, "{l1}}}")?;
6552                writeln!(self.out, "}}")?;
6553                writeln!(self.out)?;
6554            }
6555            _ => {
6556                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6557                let l1 = back::Level(1);
6558                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6559                writeln!(
6560                    self.out,
6561                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6562                )?;
6563                writeln!(self.out, "}}")?;
6564                writeln!(self.out)?;
6565            }
6566        }
6567        Ok(())
6568    }
6569
6570    fn write_wrapped_image_query(
6571        &mut self,
6572        module: &crate::Module,
6573        func_ctx: &back::FunctionCtx,
6574        image: Handle<crate::Expression>,
6575        query: crate::ImageQuery,
6576    ) -> BackendResult {
6577        // We currently only need to wrap size image queries for external textures
6578        if !matches!(query, crate::ImageQuery::Size { .. }) {
6579            return Ok(());
6580        }
6581        let class = match *func_ctx.resolve_type(image, &module.types) {
6582            crate::TypeInner::Image { class, .. } => class,
6583            _ => unreachable!(),
6584        };
6585        if class != crate::ImageClass::External {
6586            return Ok(());
6587        }
6588        let wrapped = WrappedFunction::ImageQuerySize { class };
6589        if !self.wrapped_functions.insert(wrapped) {
6590            return Ok(());
6591        }
6592        writeln!(
6593            self.out,
6594            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6595        )?;
6596        let l1 = back::Level(1);
6597        let l2 = l1.next();
6598        writeln!(
6599            self.out,
6600            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6601        )?;
6602        writeln!(self.out, "{l2}return tex.params.size;")?;
6603        writeln!(self.out, "{l1}}} else {{")?;
6604        // params.size == (0, 0) indicates to query and return plane 0's actual size
6605        writeln!(
6606            self.out,
6607            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6608        )?;
6609        writeln!(self.out, "{l1}}}")?;
6610        writeln!(self.out, "}}")?;
6611        writeln!(self.out)?;
6612        Ok(())
6613    }
6614
6615    fn write_wrapped_cooperative_load(
6616        &mut self,
6617        module: &crate::Module,
6618        func_ctx: &back::FunctionCtx,
6619        columns: crate::CooperativeSize,
6620        rows: crate::CooperativeSize,
6621        pointer: Handle<crate::Expression>,
6622    ) -> BackendResult {
6623        let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6624        let space = ptr_ty.pointer_space().unwrap();
6625        let space_name = space.to_msl_name().unwrap_or_default();
6626        let scalar = ptr_ty
6627            .pointer_base_type()
6628            .unwrap()
6629            .inner_with(&module.types)
6630            .scalar()
6631            .unwrap();
6632        let wrapped = WrappedFunction::CooperativeLoad {
6633            space_name,
6634            columns,
6635            rows,
6636            scalar,
6637        };
6638        if !self.wrapped_functions.insert(wrapped) {
6639            return Ok(());
6640        }
6641        let scalar_name = scalar.to_msl_name();
6642        writeln!(
6643            self.out,
6644            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6645            columns as u32, rows as u32,
6646        )?;
6647        let l1 = back::Level(1);
6648        writeln!(
6649            self.out,
6650            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6651            columns as u32, rows as u32
6652        )?;
6653        let matrix_origin = "0";
6654        writeln!(
6655            self.out,
6656            "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6657        )?;
6658        writeln!(self.out, "{l1}return m;")?;
6659        writeln!(self.out, "}}")?;
6660        writeln!(self.out)?;
6661        Ok(())
6662    }
6663
6664    fn write_wrapped_cooperative_multiply_add(
6665        &mut self,
6666        module: &crate::Module,
6667        func_ctx: &back::FunctionCtx,
6668        space: crate::AddressSpace,
6669        a: Handle<crate::Expression>,
6670        b: Handle<crate::Expression>,
6671        c: Handle<crate::Expression>,
6672    ) -> BackendResult {
6673        let space_name = space.to_msl_name().unwrap_or_default();
6674        let (a_c, a_r, ab_scalar) = match *func_ctx.resolve_type(a, &module.types) {
6675            crate::TypeInner::CooperativeMatrix {
6676                columns,
6677                rows,
6678                scalar,
6679                ..
6680            } => (columns, rows, scalar),
6681            _ => unreachable!(),
6682        };
6683        let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6684            crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6685            _ => unreachable!(),
6686        };
6687        let c_scalar = match *func_ctx.resolve_type(c, &module.types) {
6688            crate::TypeInner::CooperativeMatrix { scalar, .. } => scalar,
6689            _ => unreachable!(),
6690        };
6691        let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6692            space_name,
6693            columns: b_c,
6694            rows: a_r,
6695            intermediate: a_c,
6696            ab_scalar,
6697            c_scalar,
6698        };
6699        if !self.wrapped_functions.insert(wrapped) {
6700            return Ok(());
6701        }
6702        let ab_scalar_name = ab_scalar.to_msl_name();
6703        let c_scalar_name = c_scalar.to_msl_name();
6704        writeln!(
6705            self.out,
6706            "{NAMESPACE}::simdgroup_{c_scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{ab_scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{ab_scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{c_scalar_name}{}x{}& c) {{",
6707            b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6708        )?;
6709        let l1 = back::Level(1);
6710        writeln!(
6711            self.out,
6712            "{l1}{NAMESPACE}::simdgroup_{c_scalar_name}{}x{} d;",
6713            b_c as u32, a_r as u32
6714        )?;
6715        writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6716        writeln!(self.out, "{l1}return d;")?;
6717        writeln!(self.out, "}}")?;
6718        writeln!(self.out)?;
6719        Ok(())
6720    }
6721
6722    pub(super) fn write_wrapped_functions(
6723        &mut self,
6724        module: &crate::Module,
6725        func_ctx: &back::FunctionCtx,
6726        options: &Options,
6727    ) -> BackendResult {
6728        for (expr_handle, expr) in func_ctx.expressions.iter() {
6729            match *expr {
6730                crate::Expression::Unary { op, expr: operand } => {
6731                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6732                }
6733                crate::Expression::Binary { op, left, right } => {
6734                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6735                }
6736                crate::Expression::Math {
6737                    fun,
6738                    arg,
6739                    arg1,
6740                    arg2,
6741                    arg3,
6742                } => {
6743                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6744                }
6745                crate::Expression::As {
6746                    expr,
6747                    kind,
6748                    convert,
6749                } => {
6750                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6751                }
6752                crate::Expression::ImageLoad {
6753                    image,
6754                    coordinate,
6755                    array_index,
6756                    sample,
6757                    level,
6758                } => {
6759                    self.write_wrapped_image_load(
6760                        module,
6761                        func_ctx,
6762                        image,
6763                        coordinate,
6764                        array_index,
6765                        sample,
6766                        level,
6767                    )?;
6768                }
6769                crate::Expression::ImageSample {
6770                    image,
6771                    sampler,
6772                    gather,
6773                    coordinate,
6774                    array_index,
6775                    offset,
6776                    level,
6777                    depth_ref,
6778                    clamp_to_edge,
6779                } => {
6780                    self.write_wrapped_image_sample(
6781                        module,
6782                        func_ctx,
6783                        image,
6784                        sampler,
6785                        gather,
6786                        coordinate,
6787                        array_index,
6788                        offset,
6789                        level,
6790                        depth_ref,
6791                        clamp_to_edge,
6792                    )?;
6793                }
6794                crate::Expression::ImageQuery { image, query } => {
6795                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6796                }
6797                crate::Expression::CooperativeLoad {
6798                    columns,
6799                    rows,
6800                    role: _,
6801                    ref data,
6802                } => {
6803                    self.write_wrapped_cooperative_load(
6804                        module,
6805                        func_ctx,
6806                        columns,
6807                        rows,
6808                        data.pointer,
6809                    )?;
6810                }
6811                crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
6812                    let space = crate::AddressSpace::Private;
6813                    self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b, c)?;
6814                }
6815                crate::Expression::RayQueryGetIntersection { committed, .. } => {
6816                    self.write_rq_get_intersection_function(module, committed, options)?;
6817                }
6818                _ => {}
6819            }
6820        }
6821
6822        Ok(())
6823    }
6824
6825    // Returns the array of mapped entry point names.
6826    fn write_functions(
6827        &mut self,
6828        module: &crate::Module,
6829        mod_info: &valid::ModuleInfo,
6830        options: &Options,
6831        pipeline_options: &PipelineOptions,
6832    ) -> Result<TranslationInfo, Error> {
6833        use nt::VertexFormat;
6834
6835        // Define structs to hold resolved/generated data for vertex buffers and
6836        // their attributes.
6837        struct AttributeMappingResolved {
6838            ty_name: String,
6839            dimension: Option<crate::VectorSize>,
6840            scalar: crate::Scalar,
6841            name: String,
6842        }
6843        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6844
6845        struct VertexBufferMappingResolved<'a> {
6846            id: u32,
6847            stride: u32,
6848            step_mode: back::msl::VertexBufferStepMode,
6849            ty_name: String,
6850            param_name: String,
6851            elem_name: String,
6852            attributes: &'a Vec<back::msl::AttributeMapping>,
6853        }
6854        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6855
6856        // Define a struct to hold a named reference to a byte-unpacking function.
6857        struct UnpackingFunction {
6858            name: String,
6859            byte_count: u32,
6860            dimension: Option<crate::VectorSize>,
6861            scalar: crate::Scalar,
6862        }
6863        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6864
6865        // Check if we are attempting vertex pulling. If we are, generate some
6866        // names we'll need, and iterate the vertex buffer mappings to output
6867        // all the conversion functions we'll need to unpack the attribute data.
6868        // We can re-use these names for all entry points that need them, since
6869        // those entry points also use self.namer.
6870        let mut needs_vertex_id = false;
6871        let v_id = self.namer.call("v_id");
6872
6873        let mut needs_instance_id = false;
6874        let i_id = self.namer.call("i_id");
6875        if pipeline_options.vertex_pulling_transform {
6876            for vbm in &pipeline_options.vertex_buffer_mappings {
6877                let buffer_id = vbm.id;
6878                let buffer_stride = vbm.stride;
6879
6880                assert!(
6881                    buffer_stride > 0,
6882                    "Vertex pulling requires a non-zero buffer stride."
6883                );
6884
6885                match vbm.step_mode {
6886                    back::msl::VertexBufferStepMode::Constant => {}
6887                    back::msl::VertexBufferStepMode::ByVertex => {
6888                        needs_vertex_id = true;
6889                    }
6890                    back::msl::VertexBufferStepMode::ByInstance => {
6891                        needs_instance_id = true;
6892                    }
6893                }
6894
6895                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6896                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6897                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6898
6899                vbm_resolved.push(VertexBufferMappingResolved {
6900                    id: buffer_id,
6901                    stride: buffer_stride,
6902                    step_mode: vbm.step_mode,
6903                    ty_name: buffer_ty,
6904                    param_name: buffer_param,
6905                    elem_name: buffer_elem,
6906                    attributes: &vbm.attributes,
6907                });
6908
6909                // Iterate the attributes and generate needed unpacking functions.
6910                for attribute in &vbm.attributes {
6911                    if unpacking_functions.contains_key(&attribute.format) {
6912                        continue;
6913                    }
6914                    let (name, byte_count, dimension, scalar) =
6915                        match self.write_unpacking_function(attribute.format) {
6916                            Ok((name, byte_count, dimension, scalar)) => {
6917                                (name, byte_count, dimension, scalar)
6918                            }
6919                            _ => {
6920                                continue;
6921                            }
6922                        };
6923                    unpacking_functions.insert(
6924                        attribute.format,
6925                        UnpackingFunction {
6926                            name,
6927                            byte_count,
6928                            dimension,
6929                            scalar,
6930                        },
6931                    );
6932                }
6933            }
6934        }
6935
6936        let mut pass_through_globals = Vec::new();
6937        for (fun_handle, fun) in module.functions.iter() {
6938            log::trace!(
6939                "function {:?}, handle {:?}",
6940                fun.name.as_deref().unwrap_or("(anonymous)"),
6941                fun_handle
6942            );
6943
6944            let ctx = back::FunctionCtx {
6945                ty: back::FunctionType::Function(fun_handle),
6946                info: &mod_info[fun_handle],
6947                expressions: &fun.expressions,
6948                named_expressions: &fun.named_expressions,
6949            };
6950
6951            writeln!(self.out)?;
6952            self.write_wrapped_functions(module, &ctx, options)?;
6953
6954            let fun_info = &mod_info[fun_handle];
6955            pass_through_globals.clear();
6956            let mut needs_buffer_sizes = false;
6957            for (handle, var) in module.global_variables.iter() {
6958                if !fun_info[handle].is_empty() {
6959                    if var.space.needs_pass_through() {
6960                        pass_through_globals.push(handle);
6961                    }
6962                    needs_buffer_sizes |= module.types[var.ty]
6963                        .inner
6964                        .needs_host_buffer_byte_size(&module.types);
6965                }
6966            }
6967
6968            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6969            match fun.result {
6970                Some(ref result) => {
6971                    let ty_name = TypeContext {
6972                        handle: result.ty,
6973                        gctx: module.to_ctx(),
6974                        names: &self.names,
6975                        access: crate::StorageAccess::empty(),
6976                        first_time: false,
6977                    };
6978                    write!(self.out, "{ty_name}")?;
6979                }
6980                None => {
6981                    write!(self.out, "void")?;
6982                }
6983            }
6984            writeln!(self.out, " {fun_name}(")?;
6985
6986            for (index, arg) in fun.arguments.iter().enumerate() {
6987                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6988                let param_type_name = TypeContext {
6989                    handle: arg.ty,
6990                    gctx: module.to_ctx(),
6991                    names: &self.names,
6992                    access: crate::StorageAccess::empty(),
6993                    first_time: false,
6994                };
6995                let separator = separate(
6996                    !pass_through_globals.is_empty()
6997                        || index + 1 != fun.arguments.len()
6998                        || needs_buffer_sizes,
6999                );
7000                writeln!(
7001                    self.out,
7002                    "{}{} {}{}",
7003                    back::INDENT,
7004                    param_type_name,
7005                    name,
7006                    separator
7007                )?;
7008            }
7009            for (index, &handle) in pass_through_globals.iter().enumerate() {
7010                let tyvar = TypedGlobalVariable {
7011                    module,
7012                    names: &self.names,
7013                    handle,
7014                    usage: fun_info[handle],
7015                    reference: true,
7016                };
7017                let separator =
7018                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
7019                write!(self.out, "{}", back::INDENT)?;
7020                tyvar.try_fmt(&mut self.out)?;
7021                writeln!(self.out, "{separator}")?;
7022            }
7023
7024            if needs_buffer_sizes {
7025                writeln!(
7026                    self.out,
7027                    "{}constant _mslBufferSizes& _buffer_sizes",
7028                    back::INDENT
7029                )?;
7030            }
7031
7032            writeln!(self.out, ") {{")?;
7033
7034            let guarded_indices =
7035                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7036
7037            let context = StatementContext {
7038                expression: ExpressionContext {
7039                    function: fun,
7040                    origin: FunctionOrigin::Handle(fun_handle),
7041                    info: fun_info,
7042                    lang_version: options.lang_version,
7043                    policies: options.bounds_check_policies,
7044                    guarded_indices,
7045                    module,
7046                    mod_info,
7047                    pipeline_options,
7048                    force_loop_bounding: options.force_loop_bounding,
7049                    emit_int_div_checks: options.emit_int_div_checks,
7050                    ray_query_initialization_tracking: options.ray_query_initialization_tracking,
7051                },
7052                result_struct: None,
7053            };
7054
7055            self.put_locals(&context.expression)?;
7056            self.update_expressions_to_bake(fun, fun_info, &context.expression);
7057            self.put_block(back::Level(1), &fun.body, &context)?;
7058            writeln!(self.out, "}}")?;
7059            self.named_expressions.clear();
7060        }
7061
7062        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
7063            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
7064
7065        let mut info = TranslationInfo {
7066            entry_point_names: Vec::with_capacity(ep_range.len()),
7067        };
7068
7069        for ep_index in ep_range {
7070            let ep = &module.entry_points[ep_index];
7071            let fun = &ep.function;
7072            let fun_info = mod_info.get_entry_point(ep_index);
7073            let mut ep_error = None;
7074
7075            // For vertex_id and instance_id arguments, presume that we'll
7076            // use our generated names, but switch to the name of an
7077            // existing @builtin param, if we find one.
7078            let mut v_existing_id = None;
7079            let mut i_existing_id = None;
7080
7081            log::trace!(
7082                "entry point {:?}, index {:?}",
7083                fun.name.as_deref().unwrap_or("(anonymous)"),
7084                ep_index
7085            );
7086
7087            let ctx = back::FunctionCtx {
7088                ty: back::FunctionType::EntryPoint(ep_index as u16),
7089                info: fun_info,
7090                expressions: &fun.expressions,
7091                named_expressions: &fun.named_expressions,
7092            };
7093
7094            self.write_wrapped_functions(module, &ctx, options)?;
7095
7096            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
7097                crate::ShaderStage::Vertex => (
7098                    Some("vertex"),
7099                    LocationMode::VertexInput,
7100                    LocationMode::VertexOutput,
7101                    true,
7102                ),
7103                crate::ShaderStage::Fragment => (
7104                    Some("fragment"),
7105                    LocationMode::FragmentInput,
7106                    LocationMode::FragmentOutput,
7107                    false,
7108                ),
7109                crate::ShaderStage::Compute => (
7110                    Some("kernel"),
7111                    LocationMode::Uniform,
7112                    LocationMode::Uniform,
7113                    false,
7114                ),
7115                crate::ShaderStage::Task => {
7116                    (None, LocationMode::Uniform, LocationMode::Uniform, false)
7117                }
7118                crate::ShaderStage::Mesh => {
7119                    (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
7120                }
7121                crate::ShaderStage::RayGeneration
7122                | crate::ShaderStage::AnyHit
7123                | crate::ShaderStage::ClosestHit
7124                | crate::ShaderStage::Miss => unimplemented!(),
7125            };
7126
7127            // Should this entry point be modified to do vertex pulling?
7128            let do_vertex_pulling = can_vertex_pull
7129                && pipeline_options.vertex_pulling_transform
7130                && !pipeline_options.vertex_buffer_mappings.is_empty();
7131
7132            // Is any global variable used by this entry point dynamically sized?
7133            let needs_buffer_sizes = do_vertex_pulling
7134                || module
7135                    .global_variables
7136                    .iter()
7137                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
7138                    .any(|(_, var)| {
7139                        module.types[var.ty]
7140                            .inner
7141                            .needs_host_buffer_byte_size(&module.types)
7142                    });
7143
7144            // skip this entry point if any global bindings are missing,
7145            // or their types are incompatible.
7146            if !options.fake_missing_bindings {
7147                for (var_handle, var) in module.global_variables.iter() {
7148                    if fun_info[var_handle].is_empty() {
7149                        continue;
7150                    }
7151                    match var.space {
7152                        crate::AddressSpace::Uniform
7153                        | crate::AddressSpace::Storage { .. }
7154                        | crate::AddressSpace::Handle => {
7155                            let br = match var.binding {
7156                                Some(ref br) => br,
7157                                None => {
7158                                    let var_name = var.name.clone().unwrap_or_default();
7159                                    ep_error =
7160                                        Some(super::EntryPointError::MissingBinding(var_name));
7161                                    break;
7162                                }
7163                            };
7164                            let target = options.get_resource_binding_target(ep, br);
7165                            let good = match target {
7166                                Some(target) => {
7167                                    // We intentionally don't dereference binding_arrays here,
7168                                    // so that binding arrays fall to the buffer location.
7169
7170                                    match module.types[var.ty].inner {
7171                                        crate::TypeInner::Image {
7172                                            class: crate::ImageClass::External,
7173                                            ..
7174                                        } => target.external_texture.is_some(),
7175                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
7176                                        crate::TypeInner::Sampler { .. } => {
7177                                            target.sampler.is_some()
7178                                        }
7179                                        _ => target.buffer.is_some(),
7180                                    }
7181                                }
7182                                None => false,
7183                            };
7184                            if !good {
7185                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
7186                                break;
7187                            }
7188                        }
7189                        crate::AddressSpace::Immediate => {
7190                            if let Err(e) = options.resolve_immediates(ep) {
7191                                ep_error = Some(e);
7192                                break;
7193                            }
7194                        }
7195                        crate::AddressSpace::Function
7196                        | crate::AddressSpace::Private
7197                        | crate::AddressSpace::WorkGroup
7198                        | crate::AddressSpace::TaskPayload => {}
7199                        crate::AddressSpace::RayPayload
7200                        | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
7201                    }
7202                }
7203                if needs_buffer_sizes {
7204                    if let Err(err) = options.resolve_sizes_buffer(ep) {
7205                        ep_error = Some(err);
7206                    }
7207                }
7208            }
7209
7210            if let Some(err) = ep_error {
7211                info.entry_point_names.push(Err(err));
7212                continue;
7213            }
7214            let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
7215            info.entry_point_names.push(Ok(fun_name.clone()));
7216
7217            writeln!(self.out)?;
7218
7219            // Since `Namer.reset` wasn't expecting struct members to be
7220            // suddenly injected into another namespace like this,
7221            // `self.names` doesn't keep them distinct from other variables.
7222            // Generate fresh names for these arguments, and remember the
7223            // mapping.
7224            let mut flattened_member_names = FastHashMap::default();
7225            // Varyings' members get their own namespace
7226            let mut varyings_namer = proc::Namer::default();
7227
7228            let mut empty_names = FastHashMap::default(); // Create a throwaway map
7229            varyings_namer.reset(
7230                module,
7231                &super::keywords::RESERVED_SET,
7232                proc::KeywordSet::empty(),
7233                proc::CaseInsensitiveKeywordSet::empty(),
7234                &[CLAMPED_LOD_LOAD_PREFIX],
7235                &mut empty_names,
7236            );
7237
7238            // List all the Naga `EntryPoint`'s `Function`'s arguments,
7239            // flattening structs into their members. In Metal, we will pass
7240            // each of these values to the entry point as a separate argument—
7241            // except for the varyings, handled next.
7242            let mut flattened_arguments = Vec::new();
7243            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7244                match module.types[arg.ty].inner {
7245                    crate::TypeInner::Struct { ref members, .. } => {
7246                        for (member_index, member) in members.iter().enumerate() {
7247                            let member_index = member_index as u32;
7248                            flattened_arguments.push((
7249                                NameKey::StructMember(arg.ty, member_index),
7250                                member.ty,
7251                                member.binding.as_ref(),
7252                            ));
7253                            let name_key = NameKey::StructMember(arg.ty, member_index);
7254                            let name = match member.binding {
7255                                Some(crate::Binding::Location { .. }) => {
7256                                    if do_vertex_pulling {
7257                                        self.namer.call(&self.names[&name_key])
7258                                    } else {
7259                                        varyings_namer.call(&self.names[&name_key])
7260                                    }
7261                                }
7262                                _ => self.namer.call(&self.names[&name_key]),
7263                            };
7264                            flattened_member_names.insert(name_key, name);
7265                        }
7266                    }
7267                    _ => flattened_arguments.push((
7268                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7269                        arg.ty,
7270                        arg.binding.as_ref(),
7271                    )),
7272                }
7273            }
7274
7275            // Identify the varyings among the argument values, and maybe emit
7276            // a struct type named `<fun>Input` to hold them. If we are doing
7277            // vertex pulling, we instead update our attribute mapping to
7278            // note the types, names, and zero values of the attributes.
7279            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7280            let varyings_member_name = self.namer.call("varyings");
7281            let mut has_varyings = false;
7282
7283            if !flattened_arguments.is_empty() {
7284                if !do_vertex_pulling {
7285                    writeln!(self.out, "struct {stage_in_name} {{")?;
7286                }
7287                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7288                    let Some(binding) = binding else {
7289                        continue;
7290                    };
7291                    let name = match *name_key {
7292                        NameKey::StructMember(..) => &flattened_member_names[name_key],
7293                        _ => &self.names[name_key],
7294                    };
7295                    let ty_name = TypeContext {
7296                        handle: ty,
7297                        gctx: module.to_ctx(),
7298                        names: &self.names,
7299                        access: crate::StorageAccess::empty(),
7300                        first_time: false,
7301                    };
7302                    let resolved = options.resolve_local_binding(binding, in_mode)?;
7303                    let location = match *binding {
7304                        crate::Binding::Location { location, .. } => Some(location),
7305                        crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7306                        crate::Binding::BuiltIn(_) => continue,
7307                    };
7308                    if do_vertex_pulling {
7309                        let Some(location) = location else {
7310                            continue;
7311                        };
7312                        // Update our attribute mapping.
7313                        am_resolved.insert(
7314                            location,
7315                            AttributeMappingResolved {
7316                                ty_name: ty_name.to_string(),
7317                                dimension: ty_name.vector_size(),
7318                                scalar: ty_name.scalar().unwrap(),
7319                                name: name.to_string(),
7320                            },
7321                        );
7322                    } else {
7323                        has_varyings = true;
7324                        if let super::ResolvedBinding::User {
7325                            prefix,
7326                            index,
7327                            interpolation: Some(super::ResolvedInterpolation::PerVertex),
7328                        } = resolved
7329                        {
7330                            if options.lang_version < (4, 0) {
7331                                return Err(Error::PerVertexNotSupported);
7332                            }
7333                            write!(
7334                                self.out,
7335                                "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7336                                back::INDENT,
7337                                ty_name.unwrap_array()
7338                            )?;
7339                        } else {
7340                            write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7341                            resolved.try_fmt(&mut self.out)?;
7342                        }
7343                        writeln!(self.out, ";")?;
7344                    }
7345                }
7346                if !do_vertex_pulling {
7347                    writeln!(self.out, "}};")?;
7348                }
7349            }
7350
            // If the entry point returns a value (and is not a task shader),
            // define a struct type named `<fun>Output` to hold it.
7353            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7354            let result_member_name = self.namer.call("member");
7355            let result_type_name = match fun.result {
7356                Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7357                    let mut result_members = Vec::new();
7358                    if let crate::TypeInner::Struct { ref members, .. } =
7359                        module.types[result.ty].inner
7360                    {
7361                        for (member_index, member) in members.iter().enumerate() {
7362                            result_members.push((
7363                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7364                                member.ty,
7365                                member.binding.as_ref(),
7366                            ));
7367                        }
7368                    } else {
7369                        result_members.push((
7370                            &result_member_name,
7371                            result.ty,
7372                            result.binding.as_ref(),
7373                        ));
7374                    }
7375
7376                    writeln!(self.out, "struct {stage_out_name} {{")?;
7377                    let mut has_point_size = false;
7378                    for (name, ty, binding) in result_members {
7379                        let ty_name = TypeContext {
7380                            handle: ty,
7381                            gctx: module.to_ctx(),
7382                            names: &self.names,
7383                            access: crate::StorageAccess::empty(),
7384                            first_time: true,
7385                        };
7386                        let binding = binding.ok_or_else(|| {
7387                            Error::GenericValidation("Expected binding, got None".into())
7388                        })?;
7389
7390                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7391                            has_point_size = true;
7392                            if !pipeline_options.allow_and_force_point_size {
7393                                continue;
7394                            }
7395                        }
7396
7397                        let array_len = match module.types[ty].inner {
7398                            crate::TypeInner::Array {
7399                                size: crate::ArraySize::Constant(size),
7400                                ..
7401                            } => Some(size),
7402                            _ => None,
7403                        };
7404                        let resolved = options.resolve_local_binding(binding, out_mode)?;
7405                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7406                        resolved.try_fmt(&mut self.out)?;
7407                        if let Some(array_len) = array_len {
7408                            write!(self.out, " [{array_len}]")?;
7409                        }
7410                        writeln!(self.out, ";")?;
7411                    }
7412
7413                    if pipeline_options.allow_and_force_point_size
7414                        && ep.stage == crate::ShaderStage::Vertex
7415                        && !has_point_size
7416                    {
7417                        // inject the point size output last
7418                        writeln!(
7419                            self.out,
7420                            "{}float _point_size [[point_size]];",
7421                            back::INDENT
7422                        )?;
7423                    }
7424                    writeln!(self.out, "}};")?;
7425                    &stage_out_name
7426                }
7427                Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7428                    assert_eq!(
7429                        module.types[result.ty].inner,
7430                        crate::TypeInner::Vector {
7431                            size: crate::VectorSize::Tri,
7432                            scalar: crate::Scalar::U32
7433                        }
7434                    );
7435
7436                    "metal::uint3"
7437                }
7438                _ => "void",
7439            };
7440
7441            let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7442                Some(self.write_mesh_output_types(
7443                    mesh_info,
7444                    &fun_name,
7445                    module,
7446                    pipeline_options.allow_and_force_point_size,
7447                    options,
7448                )?)
7449            } else {
7450                None
7451            };
7452
7453            // If we're doing a vertex pulling transform, define the buffer
7454            // structure types.
7455            if do_vertex_pulling {
7456                for vbm in &vbm_resolved {
7457                    let buffer_stride = vbm.stride;
7458                    let buffer_ty = &vbm.ty_name;
7459
7460                    // Define a structure of bytes of the appropriate size.
7461                    // When we access the attributes, we'll be unpacking these
7462                    // bytes at some offset.
7463                    writeln!(
7464                        self.out,
7465                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7466                    )?;
7467                }
7468            }
7469
7470            let is_wrapped = matches!(
7471                ep.stage,
7472                crate::ShaderStage::Task | crate::ShaderStage::Mesh
7473            );
7474            let fun_name = fun_name.clone();
7475            let nested_fun_name = if is_wrapped {
7476                self.namer.call(&format!("_{fun_name}"))
7477            } else {
7478                fun_name.clone()
7479            };
7480
7481            // https://web.archive.org/web/20181029003926/https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
7482            if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7483                let total_threads =
7484                    ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7485                write!(
7486                    self.out,
7487                    "[[max_total_threads_per_threadgroup({total_threads})]] "
7488                )?;
7489            }
7490
7491            // Write the entry point function's name, and begin its argument list.
7492            if let Some(em_str) = em_str {
7493                write!(self.out, "{em_str} ")?;
7494            }
7495            writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7496
7497            let mut args = Vec::new();
7498
7499            // If we have produced a struct holding the `EntryPoint`'s
7500            // `Function`'s arguments' varyings, pass that struct first.
7501            if has_varyings {
7502                args.push(EntryPointArgument {
7503                    ty_name: stage_in_name,
7504                    name: varyings_member_name.clone(),
7505                    binding: " [[stage_in]]".to_string(),
7506                    init: None,
7507                });
7508            }
7509
7510            let mut local_invocation_index = None;
7511
7512            // Then pass the remaining arguments not included in the varyings
7513            // struct.
7514            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7515                let binding = match binding {
7516                    Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7517                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7518                    _ => continue,
7519                };
7520                let name = match *name_key {
7521                    NameKey::StructMember(..) => &flattened_member_names[name_key],
7522                    _ => &self.names[name_key],
7523                };
7524
7525                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7526                    local_invocation_index = Some(name_key);
7527                }
7528
7529                let ty_name = TypeContext {
7530                    handle: ty,
7531                    gctx: module.to_ctx(),
7532                    names: &self.names,
7533                    access: crate::StorageAccess::empty(),
7534                    first_time: false,
7535                };
7536
7537                match *binding {
7538                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7539                        v_existing_id = Some(name.clone());
7540                    }
7541                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7542                        i_existing_id = Some(name.clone());
7543                    }
7544                    _ => {}
7545                };
7546
7547                let resolved = options.resolve_local_binding(binding, in_mode)?;
7548                let mut binding = String::new();
7549                resolved.try_fmt(&mut binding)?;
7550
7551                args.push(EntryPointArgument {
7552                    ty_name: format!("{ty_name}"),
7553                    name: name.clone(),
7554                    binding,
7555                    init: None,
7556                });
7557            }
7558
7559            let need_workgroup_variables_initialization =
7560                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7561
7562            if local_invocation_index.is_none()
7563                && (need_workgroup_variables_initialization
7564                    || ep.stage == crate::ShaderStage::Task
7565                    || ep.stage == crate::ShaderStage::Mesh)
7566            {
7567                args.push(EntryPointArgument {
7568                    ty_name: "uint".to_string(),
7569                    name: "__local_invocation_index".to_string(),
7570                    binding: " [[thread_index_in_threadgroup]]".to_string(),
7571                    init: None,
7572                });
7573            }
7574
7575            // Those global variables used by this entry point and its callees
7576            // get passed as arguments. `Private` globals are an exception, they
7577            // don't outlive this invocation, so we declare them below as locals
7578            // within the entry point.
7579            for (handle, var) in module.global_variables.iter() {
7580                let usage = fun_info[handle];
7581                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7582                    continue;
7583                }
7584
7585                if options.lang_version < (1, 2) {
7586                    match var.space {
7587                        // This restriction is not documented in the MSL spec
7588                        // but validation will fail if it is not upheld.
7589                        //
7590                        // We infer the required version from the "Function
7591                        // Buffer Read-Writes" section of [what's new], where
7592                        // the feature sets listed correspond with the ones
7593                        // supporting MSL 1.2.
7594                        //
7595                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7596                        crate::AddressSpace::Storage { access }
7597                            if access.contains(crate::StorageAccess::STORE)
7598                                && ep.stage == crate::ShaderStage::Fragment =>
7599                        {
7600                            return Err(Error::UnsupportedWritableStorageBuffer)
7601                        }
7602                        crate::AddressSpace::Handle => {
7603                            match module.types[var.ty].inner {
7604                                crate::TypeInner::Image {
7605                                    class: crate::ImageClass::Storage { access, .. },
7606                                    ..
7607                                } => {
7608                                    // This restriction is not documented in the MSL spec
7609                                    // but validation will fail if it is not upheld.
7610                                    //
7611                                    // We infer the required version from the "Function
7612                                    // Texture Read-Writes" section of [what's new], where
7613                                    // the feature sets listed correspond with the ones
7614                                    // supporting MSL 1.2.
7615                                    //
7616                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7617                                    if access.contains(crate::StorageAccess::STORE)
7618                                        && (ep.stage == crate::ShaderStage::Vertex
7619                                            || ep.stage == crate::ShaderStage::Fragment)
7620                                    {
7621                                        return Err(Error::UnsupportedWritableStorageTexture(
7622                                            ep.stage,
7623                                        ));
7624                                    }
7625
7626                                    if access.contains(
7627                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7628                                    ) {
7629                                        return Err(Error::UnsupportedRWStorageTexture);
7630                                    }
7631                                }
7632                                _ => {}
7633                            }
7634                        }
7635                        _ => {}
7636                    }
7637                }
7638
7639                // Check min MSL version for binding arrays
7640                match var.space {
7641                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7642                        crate::TypeInner::BindingArray { base, .. } => {
7643                            match module.types[base].inner {
7644                                crate::TypeInner::Sampler { .. } => {
7645                                    if options.lang_version < (2, 0) {
7646                                        return Err(Error::UnsupportedArrayOf(
7647                                            "samplers".to_string(),
7648                                        ));
7649                                    }
7650                                }
7651                                crate::TypeInner::Image { class, .. } => match class {
7652                                    crate::ImageClass::Sampled { .. }
7653                                    | crate::ImageClass::Depth { .. }
7654                                    | crate::ImageClass::Storage {
7655                                        access: crate::StorageAccess::LOAD,
7656                                        ..
7657                                    } => {
7658                                        // Array of textures since:
7659                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7660                                        // - macOS: Metal 2
7661
7662                                        if options.lang_version < (2, 0) {
7663                                            return Err(Error::UnsupportedArrayOf(
7664                                                "textures".to_string(),
7665                                            ));
7666                                        }
7667                                    }
7668                                    crate::ImageClass::Storage {
7669                                        access: crate::StorageAccess::STORE,
7670                                        ..
7671                                    } => {
7672                                        // Array of write-only textures since:
7673                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7674                                        // - macOS: Metal 2
7675
7676                                        if options.lang_version < (2, 0) {
7677                                            return Err(Error::UnsupportedArrayOf(
7678                                                "write-only textures".to_string(),
7679                                            ));
7680                                        }
7681                                    }
7682                                    crate::ImageClass::Storage { .. } => {
7683                                        if options.lang_version < (3, 0) {
7684                                            return Err(Error::UnsupportedArrayOf(
7685                                                "read-write textures".to_string(),
7686                                            ));
7687                                        }
7688                                    }
7689                                    crate::ImageClass::External => {
7690                                        return Err(Error::UnsupportedArrayOf(
7691                                            "external textures".to_string(),
7692                                        ));
7693                                    }
7694                                },
7695                                _ => {
7696                                    return Err(Error::UnsupportedArrayOfType(base));
7697                                }
7698                            }
7699                        }
7700                        _ => {}
7701                    },
7702                    _ => {}
7703                }
7704
7705                // the resolves have already been checked for `!fake_missing_bindings` case
7706                let resolved = match var.space {
7707                    crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7708                    crate::AddressSpace::WorkGroup => None,
7709                    crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7710                    _ => options
7711                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7712                        .ok(),
7713                };
7714                if let Some(ref resolved) = resolved {
                    // Inline samplers are defined in the EP body
7716                    if resolved.as_inline_sampler(options).is_some() {
7717                        continue;
7718                    }
7719                }
7720
7721                match module.types[var.ty].inner {
7722                    crate::TypeInner::Image {
7723                        class: crate::ImageClass::External,
7724                        ..
7725                    } => {
7726                        // External texture global variables get lowered to 3 textures
7727                        // and a constant buffer. We must emit a separate argument for
7728                        // each of these.
7729                        let target = match resolved {
7730                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7731                                target.external_texture
7732                            }
7733                            _ => None,
7734                        };
7735
7736                        for i in 0..3 {
7737                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7738                                handle,
7739                                ExternalTextureNameKey::Plane(i),
7740                            )];
7741                            let ty_name = format!(
7742                                "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7743                            );
7744                            let name = plane_name.clone();
7745                            let binding = if let Some(ref target) = target {
7746                                format!(" [[texture({})]]", target.planes[i])
7747                            } else {
7748                                String::new()
7749                            };
7750                            args.push(EntryPointArgument {
7751                                ty_name,
7752                                name,
7753                                binding,
7754                                init: None,
7755                            });
7756                        }
7757                        let params_ty_name = &self.names
7758                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7759                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7760                            handle,
7761                            ExternalTextureNameKey::Params,
7762                        )];
7763                        let binding = if let Some(ref target) = target {
7764                            format!(" [[buffer({})]]", target.params)
7765                        } else {
7766                            String::new()
7767                        };
7768
7769                        args.push(EntryPointArgument {
7770                            ty_name: format!("constant {params_ty_name}&"),
7771                            name: params_name.clone(),
7772                            binding,
7773                            init: None,
7774                        });
7775                    }
7776                    _ => {
7777                        if var.space == crate::AddressSpace::WorkGroup
7778                            && ep.stage == crate::ShaderStage::Mesh
7779                        {
7780                            continue;
7781                        }
7782                        let tyvar = TypedGlobalVariable {
7783                            module,
7784                            names: &self.names,
7785                            handle,
7786                            usage,
7787                            reference: true,
7788                        };
7789                        let parts = tyvar.to_parts()?;
7790                        let mut binding = String::new();
7791                        if let Some(resolved) = resolved {
7792                            resolved.try_fmt(&mut binding)?;
7793                        }
7794                        args.push(EntryPointArgument {
7795                            ty_name: parts.ty_name,
7796                            name: parts.var_name,
7797                            binding,
7798                            init: var.init,
7799                        });
7800                    }
7801                }
7802            }
7803
7804            if do_vertex_pulling {
7805                if needs_vertex_id && v_existing_id.is_none() {
7806                    // Write the [[vertex_id]] argument.
7807                    args.push(EntryPointArgument {
7808                        ty_name: "uint".to_string(),
7809                        name: v_id.clone(),
7810                        binding: " [[vertex_id]]".to_string(),
7811                        init: None,
7812                    });
7813                }
7814
7815                if needs_instance_id && i_existing_id.is_none() {
7816                    args.push(EntryPointArgument {
7817                        ty_name: "uint".to_string(),
7818                        name: i_id.clone(),
7819                        binding: " [[instance_id]]".to_string(),
7820                        init: None,
7821                    });
7822                }
7823
7824                // Iterate vbm_resolved, output one argument for every vertex buffer,
7825                // using the names we generated earlier.
7826                for vbm in &vbm_resolved {
7827                    let id = &vbm.id;
7828                    let ty_name = &vbm.ty_name;
7829                    let param_name = &vbm.param_name;
7830                    args.push(EntryPointArgument {
7831                        ty_name: format!("const device {ty_name}*"),
7832                        name: param_name.clone(),
7833                        binding: format!(" [[buffer({id})]]"),
7834                        init: None,
7835                    });
7836                }
7837            }
7838
7839            // If this entry uses any variable-length arrays, their sizes are
7840            // passed as a final struct-typed argument.
7841            if needs_buffer_sizes {
7842                // this is checked earlier
7843                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7844                let mut binding = String::new();
7845                resolved.try_fmt(&mut binding)?;
7846                args.push(EntryPointArgument {
7847                    ty_name: "constant _mslBufferSizes&".to_string(),
7848                    name: "_buffer_sizes".to_string(),
7849                    binding,
7850                    init: None,
7851                });
7852            }
7853
7854            let mut is_first_arg = true;
7855            for arg in &args {
7856                if is_first_arg {
7857                    write!(self.out, "  ")?;
7858                } else {
7859                    write!(self.out, ", ")?;
7860                }
7861                is_first_arg = false;
7862                write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7863                if !is_wrapped {
7864                    write!(self.out, "{}", arg.binding)?;
7865                    if let Some(init) = arg.init {
7866                        write!(self.out, " = ")?;
7867                        self.put_const_expression(
7868                            init,
7869                            module,
7870                            mod_info,
7871                            &module.global_expressions,
7872                        )?;
7873                    }
7874                }
7875                writeln!(self.out)?;
7876            }
7877            if ep.stage == crate::ShaderStage::Mesh {
7878                for (handle, var) in module.global_variables.iter() {
7879                    if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7880                        continue;
7881                    }
7882                    if is_first_arg {
7883                        write!(self.out, "  ")?;
7884                    } else {
7885                        write!(self.out, ", ")?;
7886                    }
7887                    let ty_context = TypeContext {
7888                        handle: module.global_variables[handle].ty,
7889                        gctx: module.to_ctx(),
7890                        names: &self.names,
7891                        access: crate::StorageAccess::empty(),
7892                        first_time: false,
7893                    };
7894                    writeln!(
7895                        self.out,
7896                        "threadgroup {ty_context}& {}",
7897                        self.names[&NameKey::GlobalVariable(handle)]
7898                    )?;
7899                }
7900            }
7901
7902            // end of the entry point argument list
7903            writeln!(self.out, ") {{")?;
7904
7905            // Starting the function body.
7906            if do_vertex_pulling {
7907                // Provide zero values for all the attributes, which we will overwrite with
7908                // real data from the vertex attribute buffers, if the indices are in-bounds.
7909                for vbm in &vbm_resolved {
7910                    for attribute in vbm.attributes {
7911                        let location = attribute.shader_location;
7912                        let am_option = am_resolved.get(&location);
7913                        if am_option.is_none() {
7914                            // This bound attribute isn't used in this entry point, so
7915                            // don't bother zero-initializing it.
7916                            continue;
7917                        }
7918                        let am = am_option.unwrap();
7919                        let attribute_ty_name = &am.ty_name;
7920                        let attribute_name = &am.name;
7921
7922                        writeln!(
7923                            self.out,
7924                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7925                            back::Level(1)
7926                        )?;
7927                    }
7928
7929                    // Output a bounds check block that will set real values for the
7930                    // attributes, if the bounds are satisfied.
7931                    write!(self.out, "{}if (", back::Level(1))?;
7932
7933                    let idx = &vbm.id;
7934                    let stride = &vbm.stride;
7935                    let index_name = match vbm.step_mode {
7936                        back::msl::VertexBufferStepMode::Constant => "0",
7937                        back::msl::VertexBufferStepMode::ByVertex => {
7938                            if let Some(ref name) = v_existing_id {
7939                                name
7940                            } else {
7941                                &v_id
7942                            }
7943                        }
7944                        back::msl::VertexBufferStepMode::ByInstance => {
7945                            if let Some(ref name) = i_existing_id {
7946                                name
7947                            } else {
7948                                &i_id
7949                            }
7950                        }
7951                    };
7952                    write!(
7953                        self.out,
7954                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7955                    )?;
7956
7957                    writeln!(self.out, ") {{")?;
7958
7959                    // Pull the bytes out of the vertex buffer.
7960                    let ty_name = &vbm.ty_name;
7961                    let elem_name = &vbm.elem_name;
7962                    let param_name = &vbm.param_name;
7963
7964                    writeln!(
7965                        self.out,
7966                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7967                        back::Level(2),
7968                    )?;
7969
7970                    // Now set real values for each of the attributes, by unpacking the data
7971                    // from the buffer elements.
7972                    for attribute in vbm.attributes {
7973                        let location = attribute.shader_location;
7974                        let Some(am) = am_resolved.get(&location) else {
7975                            // This bound attribute isn't used in this entry point, so
7976                            // don't bother extracting the data. Too bad we emitted the
7977                            // unpacking function earlier -- it might not get used.
7978                            continue;
7979                        };
7980                        let attribute_name = &am.name;
7981                        let attribute_ty_name = &am.ty_name;
7982
7983                        let offset = attribute.offset;
7984                        let func = unpacking_functions
7985                            .get(&attribute.format)
7986                            .expect("Should have generated this unpacking function earlier.");
7987                        let func_name = &func.name;
7988
7989                        // Check dimensionality of the attribute compared to the unpacking
7990                        // function. If attribute dimension > unpack dimension, we have to
7991                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7992                        // scalar type. Otherwise, if attribute dimension is < unpack
7993                        // dimension, then we need to explicitly truncate the result.
7994                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7995
7996                        // We need an extra type conversion if the shader type does not
7997                        // match the type returned from the unpacking function.
7998                        let needs_conversion = am.scalar != func.scalar;
7999
8000                        if needs_padding_or_truncation != Ordering::Equal {
8001                            // Emit a comment flagging that a conversion is happening,
8002                            // since the actual logic can be at the end of a long line.
8003                            writeln!(
8004                                self.out,
8005                                "{}// {attribute_ty_name} <- {:?}",
8006                                back::Level(2),
8007                                attribute.format
8008                            )?;
8009                        }
8010
8011                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
8012
8013                        if needs_padding_or_truncation == Ordering::Greater {
8014                            // Needs padding: emit constructor call for wider type
8015                            write!(self.out, "{attribute_ty_name}(")?;
8016                        }
8017
8018                        // Emit call to unpacking function
8019                        if needs_conversion {
8020                            put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
8021                            write!(self.out, "(")?;
8022                        }
8023                        write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
8024                        for i in (offset + 1)..(offset + func.byte_count) {
8025                            write!(self.out, ", {elem_name}.data[{i}]")?;
8026                        }
8027                        write!(self.out, ")")?;
8028                        if needs_conversion {
8029                            write!(self.out, ")")?;
8030                        }
8031
8032                        match needs_padding_or_truncation {
8033                            Ordering::Greater => {
8034                                // Padding
8035                                let ty_is_int = scalar_is_int(am.scalar);
8036                                let zero_value = if ty_is_int { "0" } else { "0.0" };
8037                                let one_value = if ty_is_int { "1" } else { "1.0" };
8038                                for i in func.dimension.map_or(1, u8::from)
8039                                    ..am.dimension.map_or(1, u8::from)
8040                                {
8041                                    write!(
8042                                        self.out,
8043                                        ", {}",
8044                                        if i == 3 { one_value } else { zero_value }
8045                                    )?;
8046                                }
8047                            }
8048                            Ordering::Less => {
8049                                // Truncate to the first `am.dimension` components
8050                                write!(
8051                                    self.out,
8052                                    ".{}",
8053                                    &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
8054                                )?;
8055                            }
8056                            Ordering::Equal => {}
8057                        }
8058
8059                        if needs_padding_or_truncation == Ordering::Greater {
8060                            write!(self.out, ")")?;
8061                        }
8062
8063                        writeln!(self.out, ";")?;
8064                    }
8065
8066                    // End the bounds check / attribute setting block.
8067                    writeln!(self.out, "{}}}", back::Level(1))?;
8068                }
8069            }
8070
8071            // Metal doesn't support private mutable variables outside of functions,
8072            // so we put them here, just like the locals.
8073            for (handle, var) in module.global_variables.iter() {
8074                let usage = fun_info[handle];
8075                if usage.is_empty() {
8076                    continue;
8077                }
8078                if var.space == crate::AddressSpace::Private {
8079                    let tyvar = TypedGlobalVariable {
8080                        module,
8081                        names: &self.names,
8082                        handle,
8083                        usage,
8084
8085                        reference: false,
8086                    };
8087                    write!(self.out, "{}", back::INDENT)?;
8088                    tyvar.try_fmt(&mut self.out)?;
8089                    match var.init {
8090                        Some(value) => {
8091                            write!(self.out, " = ")?;
8092                            self.put_const_expression(
8093                                value,
8094                                module,
8095                                mod_info,
8096                                &module.global_expressions,
8097                            )?;
8098                            writeln!(self.out, ";")?;
8099                        }
8100                        None => {
8101                            writeln!(self.out, " = {{}};")?;
8102                        }
8103                    };
8104                } else if let Some(ref binding) = var.binding {
8105                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
8106                    if let Some(sampler) = resolved.as_inline_sampler(options) {
8107                        // write an inline sampler
8108                        let name = &self.names[&NameKey::GlobalVariable(handle)];
8109                        writeln!(
8110                            self.out,
8111                            "{}constexpr {}::sampler {}(",
8112                            back::INDENT,
8113                            NAMESPACE,
8114                            name
8115                        )?;
8116                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
8117                        writeln!(self.out, "{});", back::INDENT)?;
8118                    } else if let crate::TypeInner::Image {
8119                        class: crate::ImageClass::External,
8120                        ..
8121                    } = module.types[var.ty].inner
8122                    {
8123                        // Wrap the individual arguments for each external texture global
8124                        // in a struct which can be easily passed around.
8125                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
8126                        let l1 = back::Level(1);
8127                        let l2 = l1.next();
8128                        writeln!(
8129                            self.out,
8130                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
8131                        )?;
8132                        for i in 0..3 {
8133                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
8134                                handle,
8135                                ExternalTextureNameKey::Plane(i),
8136                            )];
8137                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
8138                        }
8139                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
8140                            handle,
8141                            ExternalTextureNameKey::Params,
8142                        )];
8143                        writeln!(self.out, "{l2}.params = {params_name},")?;
8144                        writeln!(self.out, "{l1}}};")?;
8145                    }
8146                }
8147            }
8148
8149            if need_workgroup_variables_initialization {
8150                self.write_workgroup_variables_initialization(
8151                    module,
8152                    mod_info,
8153                    fun_info,
8154                    local_invocation_index,
8155                    ep.stage,
8156                )?;
8157            }
8158
8159            // Now take the arguments that we gathered into structs, and the
8160            // structs that we flattened into arguments, and emit local
8161            // variables with initializers that put everything back the way the
8162            // body code expects.
8163            //
8164            // If we had to generate fresh names for struct members passed as
8165            // arguments, be sure to use those names when rebuilding the struct.
8166            //
8167            // "Each day, I change some zeros to ones, and some ones to zeros.
8168            // The rest, I leave alone."
8169            for (arg_index, arg) in fun.arguments.iter().enumerate() {
8170                let arg_name =
8171                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
8172                match module.types[arg.ty].inner {
8173                    crate::TypeInner::Struct { ref members, .. } => {
8174                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
8175                        write!(
8176                            self.out,
8177                            "{}const {} {} = {{ ",
8178                            back::INDENT,
8179                            struct_name,
8180                            arg_name
8181                        )?;
8182                        for (member_index, member) in members.iter().enumerate() {
8183                            let key = NameKey::StructMember(arg.ty, member_index as u32);
8184                            let name = &flattened_member_names[&key];
8185                            if member_index != 0 {
8186                                write!(self.out, ", ")?;
8187                            }
8188                            // insert padding initialization, if needed
8189                            if self
8190                                .struct_member_pads
8191                                .contains(&(arg.ty, member_index as u32))
8192                            {
8193                                write!(self.out, "{{}}, ")?;
8194                            }
8195                            match member.binding {
8196                                Some(crate::Binding::Location {
8197                                    interpolation: Some(crate::Interpolation::PerVertex),
8198                                    ..
8199                                }) => {
8200                                    writeln!(
8201                                        self.out,
8202                                        "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
8203                                        back::INDENT,
8204                                        varyings_member_name,
8205                                        arg_name,
8206                                    )?;
8207                                    continue;
8208                                }
8209                                Some(crate::Binding::Location { .. }) => {
8210                                    if has_varyings {
8211                                        write!(self.out, "{varyings_member_name}.")?;
8212                                    }
8213                                }
8214                                _ => (),
8215                            }
8216                            write!(self.out, "{name}")?;
8217                        }
8218                        writeln!(self.out, " }};")?;
8219                    }
8220                    _ => match arg.binding {
8221                        Some(crate::Binding::Location {
8222                            interpolation: Some(crate::Interpolation::PerVertex),
8223                            ..
8224                        }) => {
8225                            let ty_name = TypeContext {
8226                                handle: arg.ty,
8227                                gctx: module.to_ctx(),
8228                                names: &self.names,
8229                                access: crate::StorageAccess::empty(),
8230                                first_time: false,
8231                            };
8232                            writeln!(
8233                                self.out,
8234                                "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8235                                back::INDENT,
8236                                varyings_member_name,
8237                                arg_name,
8238                            )?;
8239                        }
8240                        Some(crate::Binding::Location { .. })
8241                        | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8242                            if has_varyings {
8243                                writeln!(
8244                                    self.out,
8245                                    "{}const auto {} = {}.{};",
8246                                    back::INDENT,
8247                                    arg_name,
8248                                    varyings_member_name,
8249                                    arg_name
8250                                )?;
8251                            }
8252                        }
8253                        _ => {}
8254                    },
8255                }
8256            }
8257
8258            let guarded_indices =
8259                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8260
8261            let context = StatementContext {
8262                expression: ExpressionContext {
8263                    function: fun,
8264                    origin: FunctionOrigin::EntryPoint(ep_index as _),
8265                    info: fun_info,
8266                    lang_version: options.lang_version,
8267                    policies: options.bounds_check_policies,
8268                    guarded_indices,
8269                    module,
8270                    mod_info,
8271                    pipeline_options,
8272                    force_loop_bounding: options.force_loop_bounding,
8273                    emit_int_div_checks: options.emit_int_div_checks,
8274                    ray_query_initialization_tracking: options.ray_query_initialization_tracking,
8275                },
8276                result_struct: if ep.stage == crate::ShaderStage::Task {
8277                    None
8278                } else {
8279                    Some(&stage_out_name)
8280                },
8281            };
8282
8283            // Finally, declare all the local variables that we need
8284            //TODO: we can postpone this till the relevant expressions are emitted
8285            self.put_locals(&context.expression)?;
8286            self.update_expressions_to_bake(fun, fun_info, &context.expression);
8287            self.put_block(back::Level(1), &fun.body, &context)?;
8288            writeln!(self.out, "}}")?;
8289            if ep_index + 1 != module.entry_points.len() {
8290                writeln!(self.out)?;
8291            }
8292            self.named_expressions.clear();
8293
8294            if is_wrapped {
8295                self.write_wrapper_function(NestedFunctionInfo {
8296                    options,
8297                    ep,
8298                    module,
8299                    mod_info,
8300                    fun_info,
8301                    args,
8302                    local_invocation_index,
8303                    nested_name: &nested_fun_name,
8304                    outer_name: &fun_name,
8305                    out_mesh_info,
8306                })?;
8307            }
8308        }
8309
8310        Ok(info)
8311    }
8312
8313    pub(super) fn write_barrier(
8314        &mut self,
8315        flags: crate::Barrier,
8316        level: back::Level,
8317    ) -> BackendResult {
8318        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
8319        // so we try to avoid it here.
8320        if flags.is_empty() {
8321            writeln!(
8322                self.out,
8323                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8324            )?;
8325        }
8326        if flags.contains(crate::Barrier::STORAGE) {
8327            writeln!(
8328                self.out,
8329                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8330            )?;
8331        }
8332        if flags.contains(crate::Barrier::WORK_GROUP) {
8333            writeln!(
8334                self.out,
8335                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8336            )?;
8337            if self.needs_object_memory_barriers {
8338                writeln!(
8339                    self.out,
8340                    "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8341                )?;
8342            }
8343        }
8344        if flags.contains(crate::Barrier::SUB_GROUP) {
8345            writeln!(
8346                self.out,
8347                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8348            )?;
8349        }
8350        if flags.contains(crate::Barrier::TEXTURE) {
8351            writeln!(
8352                self.out,
8353                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8354            )?;
8355        }
8356        Ok(())
8357    }
8358}
8359
8360/// Initializing workgroup variables is more tricky for Metal because we have to deal
8361/// with atomics at the type-level (which don't have a copy constructor).
8362mod workgroup_mem_init {
8363    use crate::EntryPoint;
8364
8365    use super::*;
8366
8367    enum Access {
8368        GlobalVariable(Handle<crate::GlobalVariable>),
8369        StructMember(Handle<crate::Type>, u32),
8370        Array(usize),
8371    }
8372
8373    impl Access {
8374        fn write<W: Write>(
8375            &self,
8376            writer: &mut W,
8377            names: &FastHashMap<NameKey, String>,
8378        ) -> Result<(), core::fmt::Error> {
8379            match *self {
8380                Access::GlobalVariable(handle) => {
8381                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8382                }
8383                Access::StructMember(handle, index) => {
8384                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8385                }
8386                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8387            }
8388        }
8389    }
8390
8391    struct AccessStack {
8392        stack: Vec<Access>,
8393        array_depth: usize,
8394    }
8395
8396    impl AccessStack {
8397        const fn new() -> Self {
8398            Self {
8399                stack: Vec::new(),
8400                array_depth: 0,
8401            }
8402        }
8403
8404        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8405            let array_depth = self.array_depth;
8406            self.stack.push(Access::Array(array_depth));
8407            self.array_depth += 1;
8408            let res = cb(self, array_depth);
8409            self.stack.pop();
8410            self.array_depth -= 1;
8411            res
8412        }
8413
8414        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8415            self.stack.push(new);
8416            let res = cb(self);
8417            self.stack.pop();
8418            res
8419        }
8420
8421        fn write<W: Write>(
8422            &self,
8423            writer: &mut W,
8424            names: &FastHashMap<NameKey, String>,
8425        ) -> Result<(), core::fmt::Error> {
8426            for next in self.stack.iter() {
8427                next.write(writer, names)?;
8428            }
8429            Ok(())
8430        }
8431    }
8432
8433    impl<W: Write> Writer<W> {
8434        pub(super) fn need_workgroup_variables_initialization(
8435            &mut self,
8436            options: &Options,
8437            ep: &EntryPoint,
8438            module: &crate::Module,
8439            fun_info: &valid::FunctionInfo,
8440        ) -> bool {
8441            let is_task = ep.stage == crate::ShaderStage::Task;
8442            options.zero_initialize_workgroup_memory
8443                && ep.stage.compute_like()
8444                && module.global_variables.iter().any(|(handle, var)| {
8445                    let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8446                        || (var.space == crate::AddressSpace::TaskPayload && is_task);
8447                    !fun_info[handle].is_empty() && is_right_address_space
8448                })
8449        }
8450
8451        pub fn write_workgroup_variables_initialization(
8452            &mut self,
8453            module: &crate::Module,
8454            module_info: &valid::ModuleInfo,
8455            fun_info: &valid::FunctionInfo,
8456            local_invocation_index: Option<&NameKey>,
8457            stage: crate::ShaderStage,
8458        ) -> BackendResult {
8459            let level = back::Level(1);
8460
8461            writeln!(
8462                self.out,
8463                "{}if ({} == 0u) {{",
8464                level,
8465                local_invocation_index
8466                    .map(|name_key| self.names[name_key].as_str())
8467                    .unwrap_or("__local_invocation_index"),
8468            )?;
8469
8470            let mut access_stack = AccessStack::new();
8471
8472            let is_task = stage == crate::ShaderStage::Task;
8473            let vars = module.global_variables.iter().filter(|&(handle, var)| {
8474                let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8475                    || (var.space == crate::AddressSpace::TaskPayload && is_task);
8476                !fun_info[handle].is_empty() && is_right_address_space
8477            });
8478
8479            for (handle, var) in vars {
8480                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8481                    self.write_workgroup_variable_initialization(
8482                        module,
8483                        module_info,
8484                        var.ty,
8485                        access_stack,
8486                        level.next(),
8487                    )
8488                })?;
8489            }
8490
8491            writeln!(self.out, "{level}}}")?;
8492            self.write_barrier(crate::Barrier::WORK_GROUP, level)
8493        }
8494
8495        fn write_workgroup_variable_initialization(
8496            &mut self,
8497            module: &crate::Module,
8498            module_info: &valid::ModuleInfo,
8499            ty: Handle<crate::Type>,
8500            access_stack: &mut AccessStack,
8501            level: back::Level,
8502        ) -> BackendResult {
8503            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8504                write!(self.out, "{level}")?;
8505                access_stack.write(&mut self.out, &self.names)?;
8506                writeln!(self.out, " = {{}};")?;
8507            } else {
8508                match module.types[ty].inner {
8509                    crate::TypeInner::Atomic { .. } => {
8510                        write!(
8511                            self.out,
8512                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8513                        )?;
8514                        access_stack.write(&mut self.out, &self.names)?;
8515                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8516                    }
8517                    crate::TypeInner::Array { base, size, .. } => {
8518                        let count = match size.resolve(module.to_ctx())? {
8519                            proc::IndexableLength::Known(count) => count,
8520                            proc::IndexableLength::Dynamic => unreachable!(),
8521                        };
8522
8523                        access_stack.enter_array(|access_stack, array_depth| {
8524                            writeln!(
8525                                self.out,
8526                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8527                            )?;
8528                            self.write_workgroup_variable_initialization(
8529                                module,
8530                                module_info,
8531                                base,
8532                                access_stack,
8533                                level.next(),
8534                            )?;
8535                            writeln!(self.out, "{level}}}")?;
8536                            BackendResult::Ok(())
8537                        })?;
8538                    }
8539                    crate::TypeInner::Struct { ref members, .. } => {
8540                        for (index, member) in members.iter().enumerate() {
8541                            access_stack.enter(
8542                                Access::StructMember(ty, index as u32),
8543                                |access_stack| {
8544                                    self.write_workgroup_variable_initialization(
8545                                        module,
8546                                        module_info,
8547                                        member.ty,
8548                                        access_stack,
8549                                        level,
8550                                    )
8551                                },
8552                            )?;
8553                        }
8554                    }
8555                    _ => unreachable!(),
8556                }
8557            }
8558
8559            Ok(())
8560        }
8561    }
8562}
8563
8564impl crate::AtomicFunction {
8565    const fn to_msl(self) -> &'static str {
8566        match self {
8567            Self::Add => "fetch_add",
8568            Self::Subtract => "fetch_sub",
8569            Self::And => "fetch_and",
8570            Self::InclusiveOr => "fetch_or",
8571            Self::ExclusiveOr => "fetch_xor",
8572            Self::Min => "fetch_min",
8573            Self::Max => "fetch_max",
8574            Self::Exchange { compare: None } => "exchange",
8575            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8576        }
8577    }
8578
8579    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8580        Ok(match self {
8581            Self::Min => "min",
8582            Self::Max => "max",
8583            _ => Err(Error::FeatureNotImplemented(
8584                "64-bit atomic operation other than min/max".to_string(),
8585            ))?,
8586        })
8587    }
8588}