// naga/back/msl/writer.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::{
8    cmp::Ordering,
9    fmt::{Display, Error as FmtError, Formatter, Write},
10    iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17    ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18    TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21    arena::{Handle, HandleSet},
22    back::{
23        self, get_entry_points,
24        msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25        Baked,
26    },
27    common,
28    proc::{
29        self, concrete_int_scalars,
30        index::{self, BoundsCheck},
31        ExternalTextureNameKey, NameKey, TypeResolution,
32    },
33    valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
// This is a hack: we need to pass a pointer to an atomic,
// but generally the backend isn't putting "&" in front of every pointer.
// More general handling of pointers still needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

// Identifiers of Naga-synthesized helpers that appear in the generated MSL.
// NOTE(review): the code that emits and calls these helpers is outside this
// chunk; the names suggest each one wraps the correspondingly named operation
// (atomic compare-exchange, modf, frexp, abs, div, dot, mod, negation,
// float-to-int conversions, external-texture load/size, clamp-to-edge
// sampling) — confirm against the emitters before relying on that.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, we wrap all argument buffers in a struct that has a single generic `<T>` field.
/// This allows `NagaArgumentBufferWrapper<metal::texture<..>>*` to work. The astute among
/// you may have noticed that this should be exactly the same to the compiler, and you'd be
/// correct.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that is declared to wrap the 3 textures and parameters
/// buffer that [`crate::ImageClass::External`] variables are lowered to,
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
// NOTE(review): presumably the names of the cooperative-matrix load and
// multiply-add helpers (cf. `WrappedFunction::CooperativeLoad` and
// `WrappedFunction::CooperativeMultiplyAdd`); confirm against the emitters.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
75
76/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
77///
78/// The `sizes` slice determines whether this function writes a
79/// scalar, vector, or matrix type:
80///
81/// - An empty slice produces a scalar type.
82/// - A one-element slice produces a vector type.
83/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
84fn put_numeric_type(
85    out: &mut impl Write,
86    scalar: crate::Scalar,
87    sizes: &[crate::VectorSize],
88) -> Result<(), FmtError> {
89    match (scalar, sizes) {
90        (scalar, &[]) => {
91            write!(out, "{}", scalar.to_msl_name())
92        }
93        (scalar, &[rows]) => {
94            write!(
95                out,
96                "{}::{}{}",
97                NAMESPACE,
98                scalar.to_msl_name(),
99                common::vector_size_str(rows)
100            )
101        }
102        (scalar, &[rows, columns]) => {
103            write!(
104                out,
105                "{}::{}{}x{}",
106                NAMESPACE,
107                scalar.to_msl_name(),
108                common::vector_size_str(columns),
109                common::vector_size_str(rows)
110            )
111        }
112        (_, _) => Ok(()), // not meaningful
113    }
114}
115
116const fn scalar_is_int(scalar: crate::Scalar) -> bool {
117    use crate::ScalarKind::*;
118    match scalar.kind {
119        Sint | Uint | AbstractInt | Bool => true,
120        Float | AbstractFloat => false,
121    }
122}
123
124/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
125const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
126
127/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
128const REINTERPRET_PREFIX: &str = "reinterpreted_";
129
130/// Wrapper for identifier names for clamped level-of-detail values
131///
132/// Values of this type implement [`core::fmt::Display`], formatting as
133/// the name of the variable used to hold the cached clamped
134/// level-of-detail value for an `ImageLoad` expression.
135struct ClampedLod(Handle<crate::Expression>);
136
137impl Display for ClampedLod {
138    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
139        self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
140    }
141}
142
143/// Wrapper for generating `struct _mslBufferSizes` member names for
144/// runtime-sized array lengths.
145///
146/// On Metal, `wgpu_hal` passes the element counts for all runtime-sized arrays
147/// as an argument to the entry point. This argument's type in the MSL is
148/// `struct _mslBufferSizes`, a Naga-synthesized struct with a `uint` member for
149/// each global variable containing a runtime-sized array.
150///
151/// If `global` is a [`Handle`] for a [`GlobalVariable`] that contains a
152/// runtime-sized array, then the value `ArraySize(global)` implements
153/// [`core::fmt::Display`], formatting as the name of the struct member carrying
154/// the number of elements in that runtime-sized array.
155///
156/// [`GlobalVariable`]: crate::GlobalVariable
157struct ArraySizeMember(Handle<crate::GlobalVariable>);
158
159impl Display for ArraySizeMember {
160    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
161        self.0.write_prefixed(f, "size")
162    }
163}
164
165/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
166///
167/// Implements [`core::fmt::Display`], formatting as a name derived from
168/// `target_type` and the variable name of `orig`.
169#[derive(Clone, Copy)]
170struct Reinterpreted<'a> {
171    target_type: &'a str,
172    orig: Handle<crate::Expression>,
173}
174
175impl<'a> Reinterpreted<'a> {
176    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
177        Self { target_type, orig }
178    }
179}
180
181impl Display for Reinterpreted<'_> {
182    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
183        f.write_str(REINTERPRET_PREFIX)?;
184        f.write_str(self.target_type)?;
185        self.orig.write_prefixed(f, "_e")
186    }
187}
188
189pub(super) struct TypeContext<'a> {
190    pub handle: Handle<crate::Type>,
191    pub gctx: proc::GlobalCtx<'a>,
192    pub names: &'a FastHashMap<NameKey, String>,
193    pub access: crate::StorageAccess,
194    pub first_time: bool,
195}
196
197impl TypeContext<'_> {
198    fn scalar(&self) -> Option<crate::Scalar> {
199        let ty = &self.gctx.types[self.handle];
200        ty.inner.scalar()
201    }
202
203    fn vector_size(&self) -> Option<crate::VectorSize> {
204        let ty = &self.gctx.types[self.handle];
205        match ty.inner {
206            crate::TypeInner::Vector { size, .. } => Some(size),
207            _ => None,
208        }
209    }
210
211    fn unwrap_array(self) -> Self {
212        match self.gctx.types[self.handle].inner {
213            crate::TypeInner::Array { base, .. } => Self {
214                handle: base,
215                ..self
216            },
217            _ => self,
218        }
219    }
220}
221
/// Formats the MSL spelling of the type identified by `self.handle`.
///
/// If the type needs an alias (see `crate::Type::needs_alias`) and
/// `first_time` is `false`, the alias name from `self.names` is written.
/// Otherwise the type is spelled out in full.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Aliased types are referred to by name, except when `first_time` asks
        // for the full spelling.
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            // `put_numeric_type` expects the sizes as `[rows, columns]`.
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Requires Metal-2.3
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            crate::TypeInner::Pointer { base, space } => {
                // The pointee is always referenced by alias when it has one.
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // `AddressSpace::to_msl_name` returns `None` only for `Handle`;
                // such pointers produce no output at all.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                // Pointers are emitted as C++ references (`&` suffix).
                write!(out, "{space_name} {sub}&")
            }
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                // As with `Pointer`, an address space without an MSL name writes nothing.
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Array lengths go at the end of the type definition,
                // so just print the element type here.
                write!(out, "{sub}")
            }
            // Structs always need an alias, so this arm is reached only with
            // `first_time` set. NOTE(review): struct definitions are presumably
            // written elsewhere; reaching this arm is treated as a bug.
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Pick the texture family, multisample suffix, texel scalar and
                // access mode for each image class.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        // Multisampled textures are declared `read`; others `sample`.
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        // Depth textures are always declared with `float` texels.
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        // The access mode comes from `self.access`, not from the
                        // image class itself.
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are represented by the wrapper struct.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{}", super::ray::metal_intersector_ty())
            }
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                // Binding arrays become pointers to the argument-buffer wrapper
                // struct (see `ARGUMENT_BUFFER_WRAPPER_STRUCT`).
                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
390
391pub(super) struct TypedGlobalVariable<'a> {
392    pub module: &'a crate::Module,
393    pub names: &'a FastHashMap<NameKey, String>,
394    pub handle: Handle<crate::GlobalVariable>,
395    pub usage: valid::GlobalUse,
396    pub reference: bool,
397}
398
399struct TypedGlobalVariableParts {
400    ty_name: String,
401    var_name: String,
402}
403
404impl TypedGlobalVariable<'_> {
405    fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
406        let var = &self.module.global_variables[self.handle];
407        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
408
409        let storage_access = match var.space {
410            crate::AddressSpace::Storage { access } => access,
411            _ => match self.module.types[var.ty].inner {
412                crate::TypeInner::Image {
413                    class: crate::ImageClass::Storage { access, .. },
414                    ..
415                } => access,
416                crate::TypeInner::BindingArray { base, .. } => {
417                    match self.module.types[base].inner {
418                        crate::TypeInner::Image {
419                            class: crate::ImageClass::Storage { access, .. },
420                            ..
421                        } => access,
422                        _ => crate::StorageAccess::default(),
423                    }
424                }
425                _ => crate::StorageAccess::default(),
426            },
427        };
428        let ty_name = TypeContext {
429            handle: var.ty,
430            gctx: self.module.to_ctx(),
431            names: self.names,
432            access: storage_access,
433            first_time: false,
434        };
435
436        let access = if var.space.needs_access_qualifier()
437            && !self.usage.intersects(valid::GlobalUse::WRITE)
438        {
439            "const"
440        } else {
441            ""
442        };
443        let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
444            (Some(space), crate::AddressSpace::WorkGroup) => {
445                ("", space, access, if self.reference { "&" } else { "" })
446            }
447            (Some(space), _) if self.reference => {
448                let coherent = if var
449                    .memory_decorations
450                    .contains(crate::MemoryDecorations::COHERENT)
451                {
452                    "coherent "
453                } else {
454                    ""
455                };
456                (coherent, space, access, "&")
457            }
458            _ => ("", "", "", ""),
459        };
460
461        let ty = format!(
462            "{coherent}{space}{}{ty_name}{}{access}{reference}",
463            if space.is_empty() { "" } else { " " },
464            if access.is_empty() { "" } else { " " },
465        );
466
467        Ok(TypedGlobalVariableParts {
468            ty_name: ty,
469            var_name: name.clone(),
470        })
471    }
472    pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
473        let parts = self.to_parts()?;
474
475        Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
476    }
477}
478
/// Key describing one Naga-generated MSL helper function.
///
/// [`Writer::wrapped_functions`] is a set of these keys. Each variant records
/// the operation and the types that distinguish one helper from another.
/// NOTE(review): presumably used so each distinct helper is emitted only once;
/// confirm against the code that inserts into `wrapped_functions`.
///
/// Operand types are written as `(Option<VectorSize>, Scalar)` pairs.
/// NOTE(review): `None` presumably denotes a scalar operand.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// Helper for unary operator `op` applied to an operand of type `ty`.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for binary operator `op` with the given operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for math function `fun` with argument type `arg_ty`.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` (optionally vectorized) to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Image-load helper for images of the given class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Image-sample helper for images of the given class, with or without
    /// clamp-to-edge sampling.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Image size-query helper for images of the given class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load helper, keyed by source address space name,
    /// matrix dimensions and scalar type.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Cooperative-matrix multiply-add helper, keyed by address space name,
    /// matrix dimensions (including the shared `intermediate` dimension) and
    /// scalar type.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Ray-query intersection getter, for the committed or candidate intersection.
    RayQueryGetIntersection {
        committed: bool,
    },
}
526
/// The Metal Shading Language backend writer.
///
/// NOTE(review): `out` presumably receives the generated MSL text; the
/// writing methods live outside this chunk.
pub struct Writer<W> {
    /// Destination for the generated output.
    pub(super) out: W,
    /// Names chosen for module items, keyed by [`NameKey`].
    pub(super) names: FastHashMap<NameKey, String>,
    /// NOTE(review): presumably the named expressions of the function
    /// currently being written; confirm against the function writers.
    pub(super) named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    pub(super) need_bake_expressions: back::NeedBakeExpressions,
    /// NOTE(review): presumably generates the identifiers stored in `names`.
    pub(super) namer: proc::Namer,
    /// Keys of the Naga helper functions recorded so far (see [`WrappedFunction`]).
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    /// Test-only bookkeeping of raw addresses. NOTE(review): presumably used to
    /// check recursion in `put_expression`; confirm in the tests.
    #[cfg(test)]
    pub(super) put_expression_stack_pointers: FastHashSet<*const ()>,
    /// Test-only bookkeeping of raw addresses. NOTE(review): presumably used to
    /// check recursion in `put_block`; confirm in the tests.
    #[cfg(test)]
    pub(super) put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    pub(super) struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    /// NOTE(review): presumably set when object-data (task payload) memory
    /// barriers must be emitted; confirm against its writers.
    pub(super) needs_object_memory_barriers: bool,
}
544
545impl crate::Scalar {
546    pub(super) fn to_msl_name(self) -> &'static str {
547        use crate::ScalarKind as Sk;
548        match self {
549            Self {
550                kind: Sk::Float,
551                width: 4,
552            } => "float",
553            Self {
554                kind: Sk::Float,
555                width: 2,
556            } => "half",
557            Self {
558                kind: Sk::Sint,
559                width: 2,
560            } => "short",
561            Self {
562                kind: Sk::Uint,
563                width: 2,
564            } => "ushort",
565            Self {
566                kind: Sk::Sint,
567                width: 4,
568            } => "int",
569            Self {
570                kind: Sk::Uint,
571                width: 4,
572            } => "uint",
573            Self {
574                kind: Sk::Sint,
575                width: 8,
576            } => "long",
577            Self {
578                kind: Sk::Uint,
579                width: 8,
580            } => "ulong",
581            Self {
582                kind: Sk::Bool,
583                width: _,
584            } => "bool",
585            Self {
586                kind: Sk::AbstractInt | Sk::AbstractFloat,
587                width: _,
588            } => unreachable!("Found Abstract scalar kind"),
589            _ => unreachable!("Unsupported scalar kind: {:?}", self),
590        }
591    }
592}
593
/// The list separator to write before an item: `","` when one is needed,
/// otherwise the empty string.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
601
602fn should_pack_struct_member(
603    members: &[crate::StructMember],
604    span: u32,
605    index: usize,
606    module: &crate::Module,
607) -> Option<crate::Scalar> {
608    let member = &members[index];
609
610    let ty_inner = &module.types[member.ty].inner;
611    let last_offset = member.offset + ty_inner.size(module.to_ctx());
612    let next_offset = match members.get(index + 1) {
613        Some(next) => next.offset,
614        None => span,
615    };
616    let is_tight = next_offset == last_offset;
617
618    match *ty_inner {
619        crate::TypeInner::Vector {
620            size: crate::VectorSize::Tri,
621            scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
622        } if is_tight => Some(scalar),
623        _ => None,
624    }
625}
626
627fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
628    match arena[ty].inner {
629        crate::TypeInner::Struct { ref members, .. } => {
630            if let Some(member) = members.last() {
631                if let crate::TypeInner::Array {
632                    size: crate::ArraySize::Dynamic,
633                    ..
634                } = arena[member.ty].inner
635                {
636                    return true;
637                }
638            }
639            false
640        }
641        crate::TypeInner::Array {
642            size: crate::ArraySize::Dynamic,
643            ..
644        } => true,
645        _ => false,
646    }
647}
648
649impl crate::AddressSpace {
650    /// Returns true if global variables in this address space are
651    /// passed in function arguments. These arguments need to be
652    /// passed through any functions called from the entry point.
653    const fn needs_pass_through(&self) -> bool {
654        match *self {
655            Self::Uniform
656            | Self::Storage { .. }
657            | Self::Private
658            | Self::WorkGroup
659            | Self::Immediate
660            | Self::Handle
661            | Self::TaskPayload => true,
662            Self::Function => false,
663            Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
664        }
665    }
666
667    /// Returns true if the address space may need a "const" qualifier.
668    const fn needs_access_qualifier(&self) -> bool {
669        match *self {
670            //Note: we are ignoring the storage access here, and instead
671            // rely on the actual use of a global by functions. This means we
672            // may end up with "const" even if the binding is read-write,
673            // and that should be OK.
674            Self::Storage { .. } => true,
675            Self::TaskPayload => true,
676            Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
677            // These should always be read-write.
678            Self::Private | Self::WorkGroup => false,
679            // These translate to `constant` address space, no need for qualifiers.
680            Self::Uniform | Self::Immediate => false,
681            // Not applicable.
682            Self::Handle | Self::Function => false,
683        }
684    }
685
686    const fn to_msl_name(self) -> Option<&'static str> {
687        match self {
688            Self::Handle => None,
689            Self::Uniform | Self::Immediate => Some("constant"),
690            Self::Storage { .. } => Some("device"),
691            // note for `RayPayload`, this probably needs to be emulated as a
692            // private variable, as metal has essentially an inout input
693            // for where it is passed.
694            Self::Private | Self::Function | Self::RayPayload => Some("thread"),
695            Self::WorkGroup => Some("threadgroup"),
696            Self::TaskPayload => Some("object_data"),
697            Self::IncomingRayPayload => Some("ray_data"),
698        }
699    }
700}
701
702impl crate::Type {
703    // Returns `true` if we need to emit an alias for this type.
704    const fn needs_alias(&self) -> bool {
705        use crate::TypeInner as Ti;
706
707        match self.inner {
708            // value types are concise enough, we only alias them if they are named
709            Ti::Scalar(_)
710            | Ti::Vector { .. }
711            | Ti::Matrix { .. }
712            | Ti::CooperativeMatrix { .. }
713            | Ti::Atomic(_)
714            | Ti::Pointer { .. }
715            | Ti::ValuePointer { .. } => self.name.is_some(),
716            // composite types are better to be aliased, regardless of the name
717            Ti::Struct { .. } | Ti::Array { .. } => true,
718            // handle types may be different, depending on the global var access, so we always inline them
719            Ti::Image { .. }
720            | Ti::Sampler { .. }
721            | Ti::AccelerationStructure { .. }
722            | Ti::RayQuery { .. }
723            | Ti::BindingArray { .. } => false,
724        }
725    }
726}
727
728#[derive(Clone, Copy)]
729enum FunctionOrigin {
730    Handle(Handle<crate::Function>),
731    EntryPoint(proc::EntryPointIndex),
732}
733
734trait NameKeyExt {
735    fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
736        match origin {
737            FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
738            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
739        }
740    }
741
742    /// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
743    /// policy when it needs to produce a pointer-typed result for an OOB access. These
744    /// are unique per accessed type, so the second argument is a type handle. See docs
745    /// for [`crate::back::msl`].
746    fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
747        match origin {
748            FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
749            FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
750        }
751    }
752}
753
754impl NameKeyExt for NameKey {}
755
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use this level-of-detail expression as is.
    Direct(Handle<crate::Expression>),
    /// Use the cached clamped value. NOTE(review): presumably this holds the
    /// `ImageLoad` expression handle from which the temporary's name is derived
    /// (cf. `ClampedLod`).
    Restricted(Handle<crate::Expression>),
}

/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// The texel coordinate expression.
    coordinate: Handle<crate::Expression>,
    /// NOTE(review): presumably the array layer, present only for arrayed images.
    array_index: Option<Handle<crate::Expression>>,
    /// NOTE(review): presumably the sample index, present only for multisampled images.
    sample: Option<Handle<crate::Expression>>,
    /// The level of detail, if one is supplied.
    level: Option<LevelOfDetail>,
}

/// State shared by the code that writes the expressions of one function or
/// entry point.
pub(super) struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    pub(super) function: &'a crate::Function,
    /// Whether `function` is an ordinary function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`, including resolved expression types.
    pub(super) info: &'a valid::FunctionInfo,
    pub(super) module: &'a crate::Module,
    pub(super) mod_info: &'a valid::ModuleInfo,
    pub(super) pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the target MSL version as `(major, minor)`.
    pub(super) lang_version: (u8, u8),
    /// Bounds-check policies applied to indexing and image accesses.
    pub(super) policies: index::BoundsCheckPolicies,

    /// The set of expressions used as indices in `ReadZeroSkipWrite`-policy
    /// accesses. These may need to be cached in temporary variables. See
    /// `index::find_checked_indexes` for details.
    pub(super) guarded_indices: HandleSet<crate::Expression>,
    /// See [`Writer::gen_force_bounded_loop_statements`] for details.
    pub(super) force_loop_bounding: bool,
}
804
805impl<'a> ExpressionContext<'a> {
806    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
807        self.info[handle].ty.inner_with(&self.module.types)
808    }
809
810    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
811    ///
812    /// Only mipmapped images need to specify a level of detail. Since 1D
813    /// textures cannot have mipmaps, MSL requires that the level argument to
814    /// texture1d queries and accesses must be a constexpr 0. It's easiest
815    /// just to omit the level entirely for 1D textures.
816    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
817        let image_ty = self.resolve_type(image);
818        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
819            class.is_mipmapped() && dim != crate::ImageDimension::D1
820        } else {
821            false
822        }
823    }
824
825    fn choose_bounds_check_policy(
826        &self,
827        pointer: Handle<crate::Expression>,
828    ) -> index::BoundsCheckPolicy {
829        self.policies
830            .choose_policy(pointer, &self.module.types, self.info)
831    }
832
833    /// See docs for [`proc::index::access_needs_check`].
834    fn access_needs_check(
835        &self,
836        base: Handle<crate::Expression>,
837        index: index::GuardedIndex,
838    ) -> Option<index::IndexableLength> {
839        index::access_needs_check(
840            base,
841            index,
842            self.module,
843            &self.function.expressions,
844            self.info,
845        )
846    }
847
848    /// See docs for [`proc::index::bounds_check_iter`].
849    fn bounds_check_iter(
850        &self,
851        chain: Handle<crate::Expression>,
852    ) -> impl Iterator<Item = BoundsCheck> + '_ {
853        index::bounds_check_iter(chain, self.module, self.function, self.info)
854    }
855
856    /// See docs for [`proc::index::oob_local_types`].
857    fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
858        index::oob_local_types(self.module, self.function, self.info, self.policies)
859    }
860
861    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
862        match self.function.expressions[expr_handle] {
863            crate::Expression::AccessIndex { base, index } => {
864                let ty = match *self.resolve_type(base) {
865                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
866                    ref ty => ty,
867                };
868                match *ty {
869                    crate::TypeInner::Struct {
870                        ref members, span, ..
871                    } => should_pack_struct_member(members, span, index as usize, self.module),
872                    _ => None,
873                }
874            }
875            _ => None,
876        }
877    }
878}
879
/// Context for writing statements: the enclosing function's expression
/// context plus statement-level state.
pub(super) struct StatementContext<'a> {
    /// Context used to write the expressions that statements contain.
    pub(super) expression: ExpressionContext<'a>,
    /// Name of a result struct, if any. NOTE(review): presumably the
    /// entry-point output struct used when writing `return`; confirm against
    /// callers.
    pub(super) result_struct: Option<&'a str>,
}
884
885impl<W: Write> Writer<W> {
886    /// Creates a new `Writer` instance.
887    pub fn new(out: W) -> Self {
888        Writer {
889            out,
890            names: FastHashMap::default(),
891            named_expressions: Default::default(),
892            need_bake_expressions: Default::default(),
893            namer: proc::Namer::default(),
894            wrapped_functions: FastHashSet::default(),
895            #[cfg(test)]
896            put_expression_stack_pointers: Default::default(),
897            #[cfg(test)]
898            put_block_stack_pointers: Default::default(),
899            struct_member_pads: FastHashSet::default(),
900            needs_object_memory_barriers: false,
901        }
902    }
903
    /// Finishes writing and returns the output.
    ///
    /// Consumes the writer and hands back the `out` sink that all generated
    /// MSL was written into.
    // See https://github.com/rust-lang/rust-clippy/issues/4979.
    pub fn finish(self) -> W {
        self.out
    }
909
910    /// Generates statements to be inserted immediately before and at the very
911    /// start of the body of each loop, to defeat MSL infinite loop reasoning.
912    /// The 0th item of the returned tuple should be inserted immediately prior
913    /// to the loop and the 1st item should be inserted at the very start of
914    /// the loop body.
915    ///
916    /// # What is this trying to solve?
917    ///
918    /// In Metal Shading Language, an infinite loop has undefined behavior.
919    /// (This rule is inherited from C++14.) This means that, if the MSL
920    /// compiler determines that a given loop will never exit, it may assume
921    /// that it is never reached. It may thus assume that any conditions
922    /// sufficient to cause the loop to be reached must be false. Like many
923    /// optimizing compilers, MSL uses this kind of analysis to establish limits
924    /// on the range of values variables involved in those conditions might
925    /// hold.
926    ///
927    /// For example, suppose the MSL compiler sees the code:
928    ///
929    /// ```ignore
930    /// if (i >= 10) {
931    ///     while (true) { }
932    /// }
933    /// ```
934    ///
935    /// It will recognize that the `while` loop will never terminate, conclude
936    /// that it must be unreachable, and thus infer that, if this code is
937    /// reached, then `i < 10` at that point.
938    ///
939    /// Now suppose that, at some point where `i` has the same value as above,
940    /// the compiler sees the code:
941    ///
942    /// ```ignore
943    /// if (i < 10) {
944    ///     a[i] = 1;
945    /// }
946    /// ```
947    ///
948    /// Because the compiler is confident that `i < 10`, it will make the
949    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
950    ///
951    /// ```ignore
952    /// a[i] = 1;
953    /// ```
954    ///
955    /// If that `if` condition was injected by Naga to implement a bounds check,
956    /// the MSL compiler's optimizations could allow out-of-bounds array
957    /// accesses to occur.
958    ///
959    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
960    /// that a loop is infinite, so an attacker could craft a Naga module
961    /// containing an infinite loop protected by conditions that cause the Metal
962    /// compiler to remove bounds checks that Naga injected elsewhere in the
963    /// function.
964    ///
965    /// This rewrite could occur even if the conditional assignment appears
966    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
967    /// reached. This would allow the attacker to save the results of
968    /// unauthorized reads somewhere accessible before entering the infinite
969    /// loop. But even worse, the MSL compiler has been observed to simply
970    /// delete the infinite loop entirely, so that even code dominated by the
971    /// loop becomes reachable. This would make the attack even more flexible,
972    /// since shaders that would appear to never terminate would actually exit
973    /// nicely, after having stolen data from elsewhere in the GPU address
974    /// space.
975    ///
976    /// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
977    /// generates is infinite. One approach would be to add inline assembly to
978    /// each loop that is annotated as potentially branching out of the loop,
979    /// but which in fact generates no instructions. Unfortunately, inline
980    /// assembly is not handled correctly by some Metal device drivers.
981    ///
982    /// A previously used approach was to add the following code to the bottom
983    /// of every loop:
984    ///
985    /// ```ignore
986    /// if (volatile bool unpredictable = false; unpredictable)
987    ///     break;
988    /// ```
989    ///
990    /// Although the `if` condition will always be false in any real execution,
991    /// the `volatile` qualifier prevents the compiler from assuming this. Thus,
992    /// it must assume that the `break` might be reached, and hence that the
993    /// loop is not unbounded. This prevents the range analysis impact described
994    /// above. Unfortunately this prevented the compiler from making important,
995    /// and safe, optimizations such as loop unrolling and was observed to
996    /// significantly hurt performance.
997    ///
998    /// Our current approach declares a counter before every loop and
999    /// increments it every iteration, breaking after 2^64 iterations:
1000    ///
1001    /// ```ignore
1002    /// uint2 loop_bound = uint2(0);
1003    /// while (true) {
1004    ///   if (metal::all(loop_bound == uint2(4294967295))) { break; }
1005    ///   loop_bound += uint2(loop_bound.y == 4294967295, 1);
1006    /// }
1007    /// ```
1008    ///
1009    /// This convinces the compiler that the loop is finite and therefore may
1010    /// execute, whilst at the same time allowing optimizations such as loop
1011    /// unrolling. Furthermore the 64-bit counter is large enough it seems
1012    /// implausible that it would affect the execution of any shader.
1013    ///
1014    /// This approach is also used by Chromium WebGPU's Dawn shader compiler:
1015    /// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
1016    fn gen_force_bounded_loop_statements(
1017        &mut self,
1018        level: back::Level,
1019        context: &StatementContext,
1020    ) -> Option<(String, String)> {
1021        if !context.expression.force_loop_bounding {
1022            return None;
1023        }
1024
1025        let loop_bound_name = self.namer.call("loop_bound");
1026        // Count down from u32::MAX rather than up from 0 to avoid hang on
1027        // certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
1028        let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1029        let level = level.next();
1030        let break_and_inc = format!(
1031            "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1032{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1033        );
1034
1035        Some((decl, break_and_inc))
1036    }
1037
1038    fn put_call_parameters(
1039        &mut self,
1040        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1041        context: &ExpressionContext,
1042    ) -> BackendResult {
1043        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1044            writer.put_expression(expr, context, true)
1045        })
1046    }
1047
1048    fn put_call_parameters_impl<C, E>(
1049        &mut self,
1050        parameters: impl Iterator<Item = Handle<crate::Expression>>,
1051        ctx: &C,
1052        put_expression: E,
1053    ) -> BackendResult
1054    where
1055        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1056    {
1057        write!(self.out, "(")?;
1058        for (i, handle) in parameters.enumerate() {
1059            if i != 0 {
1060                write!(self.out, ", ")?;
1061            }
1062            put_expression(self, ctx, handle)?;
1063        }
1064        write!(self.out, ")")?;
1065        Ok(())
1066    }
1067
1068    /// Writes the local variables of the given function, as well as any extra
1069    /// out-of-bounds locals that are needed.
1070    ///
1071    /// The names of the OOB locals are also added to `self.names` at the same
1072    /// time.
1073    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1074        let oob_local_types = context.oob_local_types();
1075        for &ty in oob_local_types.iter() {
1076            let name_key = NameKey::oob_local_for_type(context.origin, ty);
1077            self.names.insert(name_key, self.namer.call("oob"));
1078        }
1079
1080        for (name_key, ty, init) in context
1081            .function
1082            .local_variables
1083            .iter()
1084            .map(|(local_handle, local)| {
1085                let name_key = NameKey::local(context.origin, local_handle);
1086                (name_key, local.ty, local.init)
1087            })
1088            .chain(oob_local_types.iter().map(|&ty| {
1089                let name_key = NameKey::oob_local_for_type(context.origin, ty);
1090                (name_key, ty, None)
1091            }))
1092        {
1093            let ty_name = TypeContext {
1094                handle: ty,
1095                gctx: context.module.to_ctx(),
1096                names: &self.names,
1097                access: crate::StorageAccess::empty(),
1098                first_time: false,
1099            };
1100            write!(
1101                self.out,
1102                "{}{} {}",
1103                back::INDENT,
1104                ty_name,
1105                self.names[&name_key]
1106            )?;
1107            match init {
1108                Some(value) => {
1109                    write!(self.out, " = ")?;
1110                    self.put_expression(value, context, true)?;
1111                }
1112                None => {
1113                    write!(self.out, " = {{}}")?;
1114                }
1115            };
1116            writeln!(self.out, ";")?;
1117        }
1118        Ok(())
1119    }
1120
1121    fn put_level_of_detail(
1122        &mut self,
1123        level: LevelOfDetail,
1124        context: &ExpressionContext,
1125    ) -> BackendResult {
1126        match level {
1127            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1128            LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1129        }
1130        Ok(())
1131    }
1132
1133    fn put_image_query(
1134        &mut self,
1135        image: Handle<crate::Expression>,
1136        query: &str,
1137        level: Option<LevelOfDetail>,
1138        context: &ExpressionContext,
1139    ) -> BackendResult {
1140        self.put_expression(image, context, false)?;
1141        write!(self.out, ".get_{query}(")?;
1142        if let Some(level) = level {
1143            self.put_level_of_detail(level, context)?;
1144        }
1145        write!(self.out, ")")?;
1146        Ok(())
1147    }
1148
1149    fn put_image_size_query(
1150        &mut self,
1151        image: Handle<crate::Expression>,
1152        level: Option<LevelOfDetail>,
1153        kind: crate::ScalarKind,
1154        context: &ExpressionContext,
1155    ) -> BackendResult {
1156        if let crate::TypeInner::Image {
1157            class: crate::ImageClass::External,
1158            ..
1159        } = *context.resolve_type(image)
1160        {
1161            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1162            self.put_expression(image, context, true)?;
1163            write!(self.out, ")")?;
1164            return Ok(());
1165        }
1166
1167        //Note: MSL only has separate width/height/depth queries,
1168        // so compose the result of them.
1169        let dim = match *context.resolve_type(image) {
1170            crate::TypeInner::Image { dim, .. } => dim,
1171            ref other => unreachable!("Unexpected type {:?}", other),
1172        };
1173        let scalar = crate::Scalar { kind, width: 4 };
1174        let coordinate_type = scalar.to_msl_name();
1175        match dim {
1176            crate::ImageDimension::D1 => {
1177                // Since 1D textures never have mipmaps, MSL requires that the
1178                // `level` argument be a constexpr 0. It's simplest for us just
1179                // to pass `None` and omit the level entirely.
1180                if kind == crate::ScalarKind::Uint {
1181                    // No need to construct a vector. No cast needed.
1182                    self.put_image_query(image, "width", None, context)?;
1183                } else {
1184                    // There's no definition for `int` in the `metal` namespace.
1185                    write!(self.out, "int(")?;
1186                    self.put_image_query(image, "width", None, context)?;
1187                    write!(self.out, ")")?;
1188                }
1189            }
1190            crate::ImageDimension::D2 => {
1191                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1192                self.put_image_query(image, "width", level, context)?;
1193                write!(self.out, ", ")?;
1194                self.put_image_query(image, "height", level, context)?;
1195                write!(self.out, ")")?;
1196            }
1197            crate::ImageDimension::D3 => {
1198                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1199                self.put_image_query(image, "width", level, context)?;
1200                write!(self.out, ", ")?;
1201                self.put_image_query(image, "height", level, context)?;
1202                write!(self.out, ", ")?;
1203                self.put_image_query(image, "depth", level, context)?;
1204                write!(self.out, ")")?;
1205            }
1206            crate::ImageDimension::Cube => {
1207                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1208                self.put_image_query(image, "width", level, context)?;
1209                write!(self.out, ")")?;
1210            }
1211        }
1212        Ok(())
1213    }
1214
1215    fn put_cast_to_uint_scalar_or_vector(
1216        &mut self,
1217        expr: Handle<crate::Expression>,
1218        context: &ExpressionContext,
1219    ) -> BackendResult {
1220        // coordinates in IR are int, but Metal expects uint
1221        match *context.resolve_type(expr) {
1222            crate::TypeInner::Scalar(_) => {
1223                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1224            }
1225            crate::TypeInner::Vector { size, .. } => {
1226                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1227            }
1228            _ => {
1229                return Err(Error::GenericValidation(
1230                    "Invalid type for image coordinate".into(),
1231                ))
1232            }
1233        };
1234
1235        write!(self.out, "(")?;
1236        self.put_expression(expr, context, true)?;
1237        write!(self.out, ")")?;
1238        Ok(())
1239    }
1240
1241    fn put_image_sample_level(
1242        &mut self,
1243        image: Handle<crate::Expression>,
1244        level: crate::SampleLevel,
1245        context: &ExpressionContext,
1246    ) -> BackendResult {
1247        let has_levels = context.image_needs_lod(image);
1248        match level {
1249            crate::SampleLevel::Auto => {}
1250            crate::SampleLevel::Zero => {
1251                //TODO: do we support Zero on `Sampled` image classes?
1252            }
1253            _ if !has_levels => {
1254                log::warn!("1D image can't be sampled with level {level:?}");
1255            }
1256            crate::SampleLevel::Exact(h) => {
1257                write!(self.out, ", {NAMESPACE}::level(")?;
1258                self.put_expression(h, context, true)?;
1259                write!(self.out, ")")?;
1260            }
1261            crate::SampleLevel::Bias(h) => {
1262                write!(self.out, ", {NAMESPACE}::bias(")?;
1263                self.put_expression(h, context, true)?;
1264                write!(self.out, ")")?;
1265            }
1266            crate::SampleLevel::Gradient { x, y } => {
1267                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1268                self.put_expression(x, context, true)?;
1269                write!(self.out, ", ")?;
1270                self.put_expression(y, context, true)?;
1271                write!(self.out, ")")?;
1272            }
1273        }
1274        Ok(())
1275    }
1276
1277    fn put_image_coordinate_limits(
1278        &mut self,
1279        image: Handle<crate::Expression>,
1280        level: Option<LevelOfDetail>,
1281        context: &ExpressionContext,
1282    ) -> BackendResult {
1283        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1284        write!(self.out, " - 1")?;
1285        Ok(())
1286    }
1287
1288    /// General function for writing restricted image indexes.
1289    ///
1290    /// This is used to produce restricted mip levels, array indices, and sample
1291    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
1292    /// [`Restrict`] bounds check policy.
1293    ///
1294    /// This function writes an expression of the form:
1295    ///
1296    /// ```ignore
1297    ///
1298    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
1299    ///
1300    /// ```
1301    ///
1302    /// [`ImageLoad`]: crate::Expression::ImageLoad
1303    /// [`ImageStore`]: crate::Statement::ImageStore
1304    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
1305    fn put_restricted_scalar_image_index(
1306        &mut self,
1307        image: Handle<crate::Expression>,
1308        index: Handle<crate::Expression>,
1309        limit_method: &str,
1310        context: &ExpressionContext,
1311    ) -> BackendResult {
1312        write!(self.out, "{NAMESPACE}::min(uint(")?;
1313        self.put_expression(index, context, true)?;
1314        write!(self.out, "), ")?;
1315        self.put_expression(image, context, false)?;
1316        write!(self.out, ".{limit_method}() - 1)")?;
1317        Ok(())
1318    }
1319
1320    fn put_restricted_texel_address(
1321        &mut self,
1322        image: Handle<crate::Expression>,
1323        address: &TexelAddress,
1324        context: &ExpressionContext,
1325    ) -> BackendResult {
1326        // Write the coordinate.
1327        write!(self.out, "{NAMESPACE}::min(")?;
1328        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1329        write!(self.out, ", ")?;
1330        self.put_image_coordinate_limits(image, address.level, context)?;
1331        write!(self.out, ")")?;
1332
1333        // Write the array index, if present.
1334        if let Some(array_index) = address.array_index {
1335            write!(self.out, ", ")?;
1336            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1337        }
1338
1339        // Write the sample index, if present.
1340        if let Some(sample) = address.sample {
1341            write!(self.out, ", ")?;
1342            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1343        }
1344
1345        // The level of detail should be clamped and cached by
1346        // `put_cache_restricted_level`, so we don't need to clamp it here.
1347        if let Some(level) = address.level {
1348            write!(self.out, ", ")?;
1349            self.put_level_of_detail(level, context)?;
1350        }
1351
1352        Ok(())
1353    }
1354
1355    /// Write an expression that is true if the given image access is in bounds.
1356    fn put_image_access_bounds_check(
1357        &mut self,
1358        image: Handle<crate::Expression>,
1359        address: &TexelAddress,
1360        context: &ExpressionContext,
1361    ) -> BackendResult {
1362        let mut conjunction = "";
1363
1364        // First, check the level of detail. Only if that is in bounds can we
1365        // use it to find the appropriate bounds for the coordinates.
1366        let level = if let Some(level) = address.level {
1367            write!(self.out, "uint(")?;
1368            self.put_level_of_detail(level, context)?;
1369            write!(self.out, ") < ")?;
1370            self.put_expression(image, context, true)?;
1371            write!(self.out, ".get_num_mip_levels()")?;
1372            conjunction = " && ";
1373            Some(level)
1374        } else {
1375            None
1376        };
1377
1378        // Check sample index, if present.
1379        if let Some(sample) = address.sample {
1380            write!(self.out, "uint(")?;
1381            self.put_expression(sample, context, true)?;
1382            write!(self.out, ") < ")?;
1383            self.put_expression(image, context, true)?;
1384            write!(self.out, ".get_num_samples()")?;
1385            conjunction = " && ";
1386        }
1387
1388        // Check array index, if present.
1389        if let Some(array_index) = address.array_index {
1390            write!(self.out, "{conjunction}uint(")?;
1391            self.put_expression(array_index, context, true)?;
1392            write!(self.out, ") < ")?;
1393            self.put_expression(image, context, true)?;
1394            write!(self.out, ".get_array_size()")?;
1395            conjunction = " && ";
1396        }
1397
1398        // Finally, check if the coordinates are within bounds.
1399        let coord_is_vector = match *context.resolve_type(address.coordinate) {
1400            crate::TypeInner::Vector { .. } => true,
1401            _ => false,
1402        };
1403        write!(self.out, "{conjunction}")?;
1404        if coord_is_vector {
1405            write!(self.out, "{NAMESPACE}::all(")?;
1406        }
1407        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1408        write!(self.out, " < ")?;
1409        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1410        if coord_is_vector {
1411            write!(self.out, ")")?;
1412        }
1413
1414        Ok(())
1415    }
1416
1417    fn put_image_load(
1418        &mut self,
1419        load: Handle<crate::Expression>,
1420        image: Handle<crate::Expression>,
1421        mut address: TexelAddress,
1422        context: &ExpressionContext,
1423    ) -> BackendResult {
1424        if let crate::TypeInner::Image {
1425            class: crate::ImageClass::External,
1426            ..
1427        } = *context.resolve_type(image)
1428        {
1429            write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1430            self.put_expression(image, context, true)?;
1431            write!(self.out, ", ")?;
1432            self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1433            write!(self.out, ")")?;
1434            return Ok(());
1435        }
1436
1437        match context.policies.image_load {
1438            proc::BoundsCheckPolicy::Restrict => {
1439                // Use the cached restricted level of detail, if any. Omit the
1440                // level altogether for 1D textures.
1441                if address.level.is_some() {
1442                    address.level = if context.image_needs_lod(image) {
1443                        Some(LevelOfDetail::Restricted(load))
1444                    } else {
1445                        None
1446                    }
1447                }
1448
1449                self.put_expression(image, context, false)?;
1450                write!(self.out, ".read(")?;
1451                self.put_restricted_texel_address(image, &address, context)?;
1452                write!(self.out, ")")?;
1453            }
1454            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1455                write!(self.out, "(")?;
1456                self.put_image_access_bounds_check(image, &address, context)?;
1457                write!(self.out, " ? ")?;
1458                self.put_unchecked_image_load(image, &address, context)?;
1459                write!(self.out, ": DefaultConstructible())")?;
1460            }
1461            proc::BoundsCheckPolicy::Unchecked => {
1462                self.put_unchecked_image_load(image, &address, context)?;
1463            }
1464        }
1465
1466        Ok(())
1467    }
1468
1469    fn put_unchecked_image_load(
1470        &mut self,
1471        image: Handle<crate::Expression>,
1472        address: &TexelAddress,
1473        context: &ExpressionContext,
1474    ) -> BackendResult {
1475        self.put_expression(image, context, false)?;
1476        write!(self.out, ".read(")?;
1477        // coordinates in IR are int, but Metal expects uint
1478        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1479        if let Some(expr) = address.array_index {
1480            write!(self.out, ", ")?;
1481            self.put_expression(expr, context, true)?;
1482        }
1483        if let Some(sample) = address.sample {
1484            write!(self.out, ", ")?;
1485            self.put_expression(sample, context, true)?;
1486        }
1487        if let Some(level) = address.level {
1488            if context.image_needs_lod(image) {
1489                write!(self.out, ", ")?;
1490                self.put_level_of_detail(level, context)?;
1491            }
1492        }
1493        write!(self.out, ")")?;
1494
1495        Ok(())
1496    }
1497
1498    fn put_image_atomic(
1499        &mut self,
1500        level: back::Level,
1501        image: Handle<crate::Expression>,
1502        address: &TexelAddress,
1503        fun: crate::AtomicFunction,
1504        value: Handle<crate::Expression>,
1505        context: &StatementContext,
1506    ) -> BackendResult {
1507        write!(self.out, "{level}")?;
1508        self.put_expression(image, &context.expression, false)?;
1509        let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1510            fun.to_msl_64_bit()?
1511        } else {
1512            fun.to_msl()
1513        };
1514        write!(self.out, ".atomic_{op}(")?;
1515        // coordinates in IR are int, but Metal expects uint
1516        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1517        write!(self.out, ", ")?;
1518        self.put_expression(value, &context.expression, true)?;
1519        writeln!(self.out, ");")?;
1520
1521        // Workaround for Apple Metal TBDR driver bug: fragment shader atomic
1522        // texture writes randomly drop unless followed by a standard texture
1523        // write. Insert a dead-code write behind an unprovable condition so
1524        // the compiler emits proper memory safety barriers.
1525        // See: https://projects.blender.org/blender/blender/commit/aa95220576706122d79c91c7f5c522e6c7416425
1526        let value_ty = context.expression.resolve_type(value);
1527        let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1528            (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1529            (_, Some(8)) => "ulong4(0uL)",
1530            _ => "uint4(0u)",
1531        };
1532        let coord_ty = context.expression.resolve_type(address.coordinate);
1533        let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1534            ""
1535        } else {
1536            ".x"
1537        };
1538        write!(self.out, "{level}if (")?;
1539        self.put_expression(address.coordinate, &context.expression, true)?;
1540        write!(self.out, "{x} == -99999) {{ ")?;
1541        self.put_expression(image, &context.expression, false)?;
1542        write!(self.out, ".write({zero_value}, ")?;
1543        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1544        if let Some(array_index) = address.array_index {
1545            write!(self.out, ", ")?;
1546            self.put_expression(array_index, &context.expression, true)?;
1547        }
1548        writeln!(self.out, "); }}")?;
1549
1550        Ok(())
1551    }
1552
1553    fn put_image_store(
1554        &mut self,
1555        level: back::Level,
1556        image: Handle<crate::Expression>,
1557        address: &TexelAddress,
1558        value: Handle<crate::Expression>,
1559        context: &StatementContext,
1560    ) -> BackendResult {
1561        write!(self.out, "{level}")?;
1562        self.put_expression(image, &context.expression, false)?;
1563        write!(self.out, ".write(")?;
1564        self.put_expression(value, &context.expression, true)?;
1565        write!(self.out, ", ")?;
1566        // coordinates in IR are int, but Metal expects uint
1567        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1568        if let Some(expr) = address.array_index {
1569            write!(self.out, ", ")?;
1570            self.put_expression(expr, &context.expression, true)?;
1571        }
1572        writeln!(self.out, ");")?;
1573
1574        Ok(())
1575    }
1576
1577    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1578    ///
1579    /// The 'maximum valid index' is simply one less than the array's length.
1580    ///
1581    /// This emits an expression of the form `a / b`, so the caller must
1582    /// parenthesize its output if it will be applying operators of higher
1583    /// precedence.
1584    ///
1585    /// `handle` must be the handle of a global variable whose final member is a
1586    /// dynamically sized array.
1587    fn put_dynamic_array_max_index(
1588        &mut self,
1589        handle: Handle<crate::GlobalVariable>,
1590        context: &ExpressionContext,
1591    ) -> BackendResult {
1592        let global = &context.module.global_variables[handle];
1593        let (offset, array_ty) = match context.module.types[global.ty].inner {
1594            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1595                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1596                None => return Err(Error::GenericValidation("Struct has no members".into())),
1597            },
1598            crate::TypeInner::Array {
1599                size: crate::ArraySize::Dynamic,
1600                ..
1601            } => (0, global.ty),
1602            ref ty => {
1603                return Err(Error::GenericValidation(format!(
1604                    "Expected type with dynamic array, got {ty:?}"
1605                )))
1606            }
1607        };
1608
1609        let (size, stride) = match context.module.types[array_ty].inner {
1610            crate::TypeInner::Array { base, stride, .. } => (
1611                context.module.types[base]
1612                    .inner
1613                    .size(context.module.to_ctx()),
1614                stride,
1615            ),
1616            ref ty => {
1617                return Err(Error::GenericValidation(format!(
1618                    "Expected array type, got {ty:?}"
1619                )))
1620            }
1621        };
1622
1623        // When the stride length is larger than the size, the final element's stride of
1624        // bytes would have padding following the value. But the buffer size in
1625        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1626        // enough to hold the actual values' bytes.
1627        //
1628        // So subtract off the size to get a byte size that falls at the start or within
1629        // the final element. Then divide by the stride size, to get one less than the
1630        // length, and then add one. This works even if the buffer size does include the
1631        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1632        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1633        // rules, together with draw-time validation when `minBindingSize` is zero,
1634        // prevent that.
1635        write!(
1636            self.out,
1637            "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1638            member = ArraySizeMember(handle),
1639            offset = offset,
1640            size = size,
1641            stride = stride,
1642        )?;
1643        Ok(())
1644    }
1645
1646    /// Emit code for the arithmetic expression of the dot product.
1647    ///
1648    /// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1649    /// an index. It writes out the expression for the vector component at that index.
1650    fn put_dot_product<T: Copy>(
1651        &mut self,
1652        arg: T,
1653        arg1: T,
1654        size: usize,
1655        extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1656    ) -> BackendResult {
1657        // Write parentheses around the dot product expression to prevent operators
1658        // with different precedences from applying earlier.
1659        write!(self.out, "(")?;
1660
1661        // Cycle through all the components of the vector
1662        for index in 0..size {
1663            // Write the addition to the previous product
1664            // This will print an extra '+' at the beginning but that is fine in msl
1665            write!(self.out, " + ")?;
1666            extractor(self, arg, index)?;
1667            write!(self.out, " * ")?;
1668            extractor(self, arg1, index)?;
1669        }
1670
1671        write!(self.out, ")")?;
1672        Ok(())
1673    }
1674
1675    /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1676    fn put_pack4x8(
1677        &mut self,
1678        arg: Handle<crate::Expression>,
1679        context: &ExpressionContext<'_>,
1680        was_signed: bool,
1681        clamp_bounds: Option<(&str, &str)>,
1682    ) -> Result<(), Error> {
1683        let write_arg = |this: &mut Self| -> BackendResult {
1684            if let Some((min, max)) = clamp_bounds {
1685                // Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
1686                write!(this.out, "{NAMESPACE}::clamp(")?;
1687                this.put_expression(arg, context, true)?;
1688                write!(this.out, ", {min}, {max})")?;
1689            } else {
1690                this.put_expression(arg, context, true)?;
1691            }
1692            Ok(())
1693        };
1694
1695        if context.lang_version >= (2, 1) {
1696            let packed_type = if was_signed {
1697                "packed_char4"
1698            } else {
1699                "packed_uchar4"
1700            };
1701            // Metal uses little endian byte order, which matches what WGSL expects here.
1702            write!(self.out, "as_type<uint>({packed_type}(")?;
1703            write_arg(self)?;
1704            write!(self.out, "))")?;
1705        } else {
1706            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1707            if was_signed {
1708                write!(self.out, "uint(")?;
1709            }
1710            write!(self.out, "(")?;
1711            write_arg(self)?;
1712            write!(self.out, "[0] & 0xFF) | ((")?;
1713            write_arg(self)?;
1714            write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1715            write_arg(self)?;
1716            write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1717            write_arg(self)?;
1718            write!(self.out, "[3] & 0xFF) << 24)")?;
1719            if was_signed {
1720                write!(self.out, ")")?;
1721            }
1722        }
1723
1724        Ok(())
1725    }
1726
1727    /// Emit code for the isign expression.
1728    ///
1729    fn put_isign(
1730        &mut self,
1731        arg: Handle<crate::Expression>,
1732        context: &ExpressionContext,
1733    ) -> BackendResult {
1734        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1735        let scalar = context
1736            .resolve_type(arg)
1737            .scalar()
1738            .expect("put_isign should only be called for args which have an integer scalar type")
1739            .to_msl_name();
1740        match context.resolve_type(arg) {
1741            &crate::TypeInner::Vector { size, .. } => {
1742                let size = common::vector_size_str(size);
1743                write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1744            }
1745            _ => {
1746                write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1747            }
1748        }
1749        write!(self.out, ", (")?;
1750        self.put_expression(arg, context, true)?;
1751        write!(self.out, " > 0)), {scalar}(0), (")?;
1752        self.put_expression(arg, context, true)?;
1753        write!(self.out, " == 0))")?;
1754        Ok(())
1755    }
1756
1757    pub(super) fn put_const_expression(
1758        &mut self,
1759        expr_handle: Handle<crate::Expression>,
1760        module: &crate::Module,
1761        mod_info: &valid::ModuleInfo,
1762        arena: &crate::Arena<crate::Expression>,
1763    ) -> BackendResult {
1764        self.put_possibly_const_expression(
1765            expr_handle,
1766            arena,
1767            module,
1768            mod_info,
1769            &(module, mod_info),
1770            |&(_, mod_info), expr| &mod_info[expr],
1771            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1772        )
1773    }
1774
1775    fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1776        match literal {
1777            crate::Literal::F64(_) => {
1778                return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1779            }
1780            crate::Literal::F16(value) => {
1781                if value.is_infinite() {
1782                    let sign = if value.is_sign_negative() { "-" } else { "" };
1783                    write!(self.out, "{sign}INFINITY")?;
1784                } else if value.is_nan() {
1785                    write!(self.out, "NAN")?;
1786                } else {
1787                    let suffix = if value.fract() == f16::from_f32(0.0) {
1788                        ".0h"
1789                    } else {
1790                        "h"
1791                    };
1792                    write!(self.out, "{value}{suffix}")?;
1793                }
1794            }
1795            crate::Literal::F32(value) => {
1796                if value.is_infinite() {
1797                    let sign = if value.is_sign_negative() { "-" } else { "" };
1798                    write!(self.out, "{sign}INFINITY")?;
1799                } else if value.is_nan() {
1800                    write!(self.out, "NAN")?;
1801                } else {
1802                    let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1803                    write!(self.out, "{value}{suffix}")?;
1804                }
1805            }
1806            crate::Literal::U16(value) => {
1807                write!(self.out, "static_cast<ushort>({value})")?;
1808            }
1809            crate::Literal::I16(value) => {
1810                write!(self.out, "static_cast<short>({value})")?;
1811            }
1812            crate::Literal::U32(value) => {
1813                write!(self.out, "{value}u")?;
1814            }
1815            crate::Literal::I32(value) => {
1816                // `-2147483648` is parsed as unary negation of positive 2147483648.
1817                // 2147483648 is too large for int32_t meaning the expression gets
1818                // promoted to a int64_t which is not our intention. Avoid this by instead
1819                // using `-2147483647 - 1`.
1820                if value == i32::MIN {
1821                    write!(self.out, "({} - 1)", value + 1)?;
1822                } else {
1823                    write!(self.out, "{value}")?;
1824                }
1825            }
1826            crate::Literal::U64(value) => {
1827                write!(self.out, "{value}uL")?;
1828            }
1829            crate::Literal::I64(value) => {
1830                // `-9223372036854775808` is parsed as unary negation of positive
1831                // 9223372036854775808. 9223372036854775808 is too large for int64_t
1832                // causing Metal to emit a `-Wconstant-conversion` warning, and change the
1833                // value to `-9223372036854775808`. Which would then be negated, possibly
1834                // causing undefined behaviour. Avoid this by instead using
1835                // `-9223372036854775808L - 1L`.
1836                if value == i64::MIN {
1837                    write!(self.out, "({}L - 1L)", value + 1)?;
1838                } else {
1839                    write!(self.out, "{value}L")?;
1840                }
1841            }
1842            crate::Literal::Bool(value) => {
1843                write!(self.out, "{value}")?;
1844            }
1845            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1846                return Err(Error::GenericValidation(
1847                    "Unsupported abstract literal".into(),
1848                ));
1849            }
1850        }
1851        Ok(())
1852    }
1853
1854    #[allow(clippy::too_many_arguments)]
1855    fn put_possibly_const_expression<C, I, E>(
1856        &mut self,
1857        expr_handle: Handle<crate::Expression>,
1858        expressions: &crate::Arena<crate::Expression>,
1859        module: &crate::Module,
1860        mod_info: &valid::ModuleInfo,
1861        ctx: &C,
1862        get_expr_ty: I,
1863        put_expression: E,
1864    ) -> BackendResult
1865    where
1866        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1867        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1868    {
1869        match expressions[expr_handle] {
1870            crate::Expression::Literal(literal) => {
1871                self.put_literal(literal)?;
1872            }
1873            crate::Expression::Constant(handle) => {
1874                let constant = &module.constants[handle];
1875                if constant.name.is_some() {
1876                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1877                } else {
1878                    self.put_const_expression(
1879                        constant.init,
1880                        module,
1881                        mod_info,
1882                        &module.global_expressions,
1883                    )?;
1884                }
1885            }
1886            crate::Expression::ZeroValue(ty) => {
1887                let ty_name = TypeContext {
1888                    handle: ty,
1889                    gctx: module.to_ctx(),
1890                    names: &self.names,
1891                    access: crate::StorageAccess::empty(),
1892                    first_time: false,
1893                };
1894                write!(self.out, "{ty_name} {{}}")?;
1895            }
1896            crate::Expression::Compose { ty, ref components } => {
1897                let ty_name = TypeContext {
1898                    handle: ty,
1899                    gctx: module.to_ctx(),
1900                    names: &self.names,
1901                    access: crate::StorageAccess::empty(),
1902                    first_time: false,
1903                };
1904                write!(self.out, "{ty_name}")?;
1905                match module.types[ty].inner {
1906                    crate::TypeInner::Scalar(_)
1907                    | crate::TypeInner::Vector { .. }
1908                    | crate::TypeInner::Matrix { .. } => {
1909                        self.put_call_parameters_impl(
1910                            components.iter().copied(),
1911                            ctx,
1912                            put_expression,
1913                        )?;
1914                    }
1915                    crate::TypeInner::Array { .. } => {
1916                        // Naga Arrays are Metal arrays wrapped in structs, so
1917                        // we need two levels of braces.
1918                        write!(self.out, " {{{{")?;
1919                        for (index, &component) in components.iter().enumerate() {
1920                            if index != 0 {
1921                                write!(self.out, ", ")?;
1922                            }
1923                            put_expression(self, ctx, component)?;
1924                        }
1925                        write!(self.out, "}}}}")?;
1926                    }
1927                    crate::TypeInner::Struct { .. } => {
1928                        write!(self.out, " {{")?;
1929                        for (index, &component) in components.iter().enumerate() {
1930                            if index != 0 {
1931                                write!(self.out, ", ")?;
1932                            }
1933                            // insert padding initialization, if needed
1934                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1935                                write!(self.out, "{{}}, ")?;
1936                            }
1937                            put_expression(self, ctx, component)?;
1938                        }
1939                        write!(self.out, "}}")?;
1940                    }
1941                    _ => return Err(Error::UnsupportedCompose(ty)),
1942                }
1943            }
1944            crate::Expression::Splat { size, value } => {
1945                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1946                    crate::TypeInner::Scalar(scalar) => scalar,
1947                    ref ty => {
1948                        return Err(Error::GenericValidation(format!(
1949                            "Expected splat value type must be a scalar, got {ty:?}",
1950                        )))
1951                    }
1952                };
1953                put_numeric_type(&mut self.out, scalar, &[size])?;
1954                write!(self.out, "(")?;
1955                put_expression(self, ctx, value)?;
1956                write!(self.out, ")")?;
1957            }
1958            _ => {
1959                return Err(Error::Override);
1960            }
1961        }
1962
1963        Ok(())
1964    }
1965
1966    /// Emit code for the expression `expr_handle`.
1967    ///
1968    /// The `is_scoped` argument is true if the surrounding operators have the
1969    /// precedence of the comma operator, or lower. So, for example:
1970    ///
1971    /// - Pass `true` for `is_scoped` when writing function arguments, an
1972    ///   expression statement, an initializer expression, or anything already
1973    ///   wrapped in parenthesis.
1974    ///
1975    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1976    ///   almost anything else.
1977    pub(super) fn put_expression(
1978        &mut self,
1979        expr_handle: Handle<crate::Expression>,
1980        context: &ExpressionContext,
1981        is_scoped: bool,
1982    ) -> BackendResult {
1983        // Add to the set in order to track the stack size.
1984        #[cfg(test)]
1985        self.put_expression_stack_pointers
1986            .insert(ptr::from_ref(&expr_handle).cast());
1987
1988        if let Some(name) = self.named_expressions.get(&expr_handle) {
1989            write!(self.out, "{name}")?;
1990            return Ok(());
1991        }
1992
1993        let expression = &context.function.expressions[expr_handle];
1994        match *expression {
1995            crate::Expression::Literal(_)
1996            | crate::Expression::Constant(_)
1997            | crate::Expression::ZeroValue(_)
1998            | crate::Expression::Compose { .. }
1999            | crate::Expression::Splat { .. } => {
2000                self.put_possibly_const_expression(
2001                    expr_handle,
2002                    &context.function.expressions,
2003                    context.module,
2004                    context.mod_info,
2005                    context,
2006                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2007                    |writer, context, expr| writer.put_expression(expr, context, true),
2008                )?;
2009            }
2010            crate::Expression::Override(_) => return Err(Error::Override),
2011            crate::Expression::Access { base, .. }
2012            | crate::Expression::AccessIndex { base, .. } => {
2013                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
2014                // Since `put_bounds_checks` and `put_access_chain` handle an entire
2015                // access chain at a time, recursing back through `put_expression` only
2016                // for index expressions and the base object, we will never see intermediate
2017                // `Access` or `AccessIndex` expressions here.
2018                let policy = context.choose_bounds_check_policy(base);
2019                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2020                    && self.put_bounds_checks(
2021                        expr_handle,
2022                        context,
2023                        back::Level(0),
2024                        if is_scoped { "" } else { "(" },
2025                    )?
2026                {
2027                    write!(self.out, " ? ")?;
2028                    self.put_access_chain(expr_handle, policy, context)?;
2029                    write!(self.out, " : ")?;
2030
2031                    if context.resolve_type(base).pointer_space().is_some() {
2032                        // We can't just use `DefaultConstructible` if this is a pointer.
2033                        // Instead, we create a dummy local variable to serve as pointer
2034                        // target if the access is out of bounds.
2035                        let result_ty = context.info[expr_handle]
2036                            .ty
2037                            .inner_with(&context.module.types)
2038                            .pointer_base_type();
2039                        let result_ty_handle = match result_ty {
2040                            Some(TypeResolution::Handle(handle)) => handle,
2041                            Some(TypeResolution::Value(_)) => {
2042                                // As long as the result of a pointer access expression is
2043                                // passed to a function or stored in a let binding, the
2044                                // type will be in the arena. If additional uses of
2045                                // pointers become valid, this assumption might no longer
2046                                // hold. Note that the LHS of a load or store doesn't
2047                                // take this path -- there is dedicated code in `put_load`
2048                                // and `put_store`.
2049                                unreachable!(
2050                                    "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2051                                );
2052                            }
2053                            None => {
2054                                unreachable!(
2055                                    "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2056                                )
2057                            }
2058                        };
2059                        let name_key =
2060                            NameKey::oob_local_for_type(context.origin, result_ty_handle);
2061                        self.out.write_str(&self.names[&name_key])?;
2062                    } else {
2063                        write!(self.out, "DefaultConstructible()")?;
2064                    }
2065
2066                    if !is_scoped {
2067                        write!(self.out, ")")?;
2068                    }
2069                } else {
2070                    self.put_access_chain(expr_handle, policy, context)?;
2071                }
2072            }
2073            crate::Expression::Swizzle {
2074                size,
2075                vector,
2076                pattern,
2077            } => {
2078                self.put_wrapped_expression_for_packed_vec3_access(
2079                    vector,
2080                    context,
2081                    false,
2082                    &Self::put_expression,
2083                )?;
2084                write!(self.out, ".")?;
2085                for &sc in pattern[..size as usize].iter() {
2086                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2087                }
2088            }
2089            crate::Expression::FunctionArgument(index) => {
2090                let name_key = match context.origin {
2091                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2092                    FunctionOrigin::EntryPoint(ep_index) => {
2093                        NameKey::EntryPointArgument(ep_index, index)
2094                    }
2095                };
2096                let name = &self.names[&name_key];
2097                write!(self.out, "{name}")?;
2098            }
2099            crate::Expression::GlobalVariable(handle) => {
2100                let name = &self.names[&NameKey::GlobalVariable(handle)];
2101                write!(self.out, "{name}")?;
2102            }
2103            crate::Expression::LocalVariable(handle) => {
2104                let name_key = NameKey::local(context.origin, handle);
2105                let name = &self.names[&name_key];
2106                write!(self.out, "{name}")?;
2107            }
2108            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2109            crate::Expression::ImageSample {
2110                coordinate,
2111                image,
2112                sampler,
2113                clamp_to_edge: true,
2114                gather: None,
2115                array_index: None,
2116                offset: None,
2117                level: crate::SampleLevel::Zero,
2118                depth_ref: None,
2119            } => {
2120                write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2121                self.put_expression(image, context, true)?;
2122                write!(self.out, ", ")?;
2123                self.put_expression(sampler, context, true)?;
2124                write!(self.out, ", ")?;
2125                self.put_expression(coordinate, context, true)?;
2126                write!(self.out, ")")?;
2127            }
2128            crate::Expression::ImageSample {
2129                image,
2130                sampler,
2131                gather,
2132                coordinate,
2133                array_index,
2134                offset,
2135                level,
2136                depth_ref,
2137                clamp_to_edge,
2138            } => {
2139                if clamp_to_edge {
2140                    return Err(Error::GenericValidation(
2141                        "ImageSample::clamp_to_edge should have been validated out".to_string(),
2142                    ));
2143                }
2144
2145                let main_op = match gather {
2146                    Some(_) => "gather",
2147                    None => "sample",
2148                };
2149                let comparison_op = match depth_ref {
2150                    Some(_) => "_compare",
2151                    None => "",
2152                };
2153                self.put_expression(image, context, false)?;
2154                write!(self.out, ".{main_op}{comparison_op}(")?;
2155                self.put_expression(sampler, context, true)?;
2156                write!(self.out, ", ")?;
2157                self.put_expression(coordinate, context, true)?;
2158                if let Some(expr) = array_index {
2159                    write!(self.out, ", ")?;
2160                    self.put_expression(expr, context, true)?;
2161                }
2162                if let Some(dref) = depth_ref {
2163                    write!(self.out, ", ")?;
2164                    self.put_expression(dref, context, true)?;
2165                }
2166
2167                self.put_image_sample_level(image, level, context)?;
2168
2169                if let Some(offset) = offset {
2170                    write!(self.out, ", ")?;
2171                    self.put_expression(offset, context, true)?;
2172                }
2173
2174                match gather {
2175                    None | Some(crate::SwizzleComponent::X) => {}
2176                    Some(component) => {
2177                        let is_cube_map = match *context.resolve_type(image) {
2178                            crate::TypeInner::Image {
2179                                dim: crate::ImageDimension::Cube,
2180                                ..
2181                            } => true,
2182                            _ => false,
2183                        };
2184                        // Offset always comes before the gather, except
2185                        // in cube maps where it's not applicable
2186                        if offset.is_none() && !is_cube_map {
2187                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
2188                        }
2189                        let letter = back::COMPONENTS[component as usize];
2190                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2191                    }
2192                }
2193                write!(self.out, ")")?;
2194            }
2195            crate::Expression::ImageLoad {
2196                image,
2197                coordinate,
2198                array_index,
2199                sample,
2200                level,
2201            } => {
2202                let address = TexelAddress {
2203                    coordinate,
2204                    array_index,
2205                    sample,
2206                    level: level.map(LevelOfDetail::Direct),
2207                };
2208                self.put_image_load(expr_handle, image, address, context)?;
2209            }
2210            //Note: for all the queries, the signed integers are expected,
2211            // so a conversion is needed.
2212            crate::Expression::ImageQuery { image, query } => match query {
2213                crate::ImageQuery::Size { level } => {
2214                    self.put_image_size_query(
2215                        image,
2216                        level.map(LevelOfDetail::Direct),
2217                        crate::ScalarKind::Uint,
2218                        context,
2219                    )?;
2220                }
2221                crate::ImageQuery::NumLevels => {
2222                    self.put_expression(image, context, false)?;
2223                    write!(self.out, ".get_num_mip_levels()")?;
2224                }
2225                crate::ImageQuery::NumLayers => {
2226                    self.put_expression(image, context, false)?;
2227                    write!(self.out, ".get_array_size()")?;
2228                }
2229                crate::ImageQuery::NumSamples => {
2230                    self.put_expression(image, context, false)?;
2231                    write!(self.out, ".get_num_samples()")?;
2232                }
2233            },
2234            crate::Expression::Unary { op, expr } => {
2235                let op_str = match op {
2236                    crate::UnaryOperator::Negate => {
2237                        match context.resolve_type(expr).scalar_kind() {
2238                            Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2239                            _ => "-",
2240                        }
2241                    }
2242                    crate::UnaryOperator::LogicalNot => "!",
2243                    crate::UnaryOperator::BitwiseNot => "~",
2244                };
2245                write!(self.out, "{op_str}(")?;
2246                self.put_expression(expr, context, false)?;
2247                write!(self.out, ")")?;
2248            }
2249            crate::Expression::Binary { op, left, right } => {
2250                let kind = context
2251                    .resolve_type(left)
2252                    .scalar_kind()
2253                    .ok_or(Error::UnsupportedBinaryOp(op))?;
2254
2255                if op == crate::BinaryOperator::Divide
2256                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2257                {
2258                    write!(self.out, "{DIV_FUNCTION}(")?;
2259                    self.put_expression(left, context, true)?;
2260                    write!(self.out, ", ")?;
2261                    self.put_expression(right, context, true)?;
2262                    write!(self.out, ")")?;
2263                } else if op == crate::BinaryOperator::Modulo
2264                    && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2265                {
2266                    write!(self.out, "{MOD_FUNCTION}(")?;
2267                    self.put_expression(left, context, true)?;
2268                    write!(self.out, ", ")?;
2269                    self.put_expression(right, context, true)?;
2270                    write!(self.out, ")")?;
2271                } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2272                    // TODO: handle undefined behavior of BinaryOperator::Modulo
2273                    //
2274                    // float:
2275                    // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
2276                    write!(self.out, "{NAMESPACE}::fmod(")?;
2277                    self.put_expression(left, context, true)?;
2278                    write!(self.out, ", ")?;
2279                    self.put_expression(right, context, true)?;
2280                    write!(self.out, ")")?;
2281                } else if (op == crate::BinaryOperator::Add
2282                    || op == crate::BinaryOperator::Subtract
2283                    || op == crate::BinaryOperator::Multiply)
2284                    && kind == crate::ScalarKind::Sint
2285                {
2286                    let to_unsigned = |ty: &crate::TypeInner| match *ty {
2287                        crate::TypeInner::Scalar(scalar) => {
2288                            Ok(crate::TypeInner::Scalar(crate::Scalar {
2289                                kind: crate::ScalarKind::Uint,
2290                                ..scalar
2291                            }))
2292                        }
2293                        crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2294                            size,
2295                            scalar: crate::Scalar {
2296                                kind: crate::ScalarKind::Uint,
2297                                ..scalar
2298                            },
2299                        }),
2300                        _ => Err(Error::UnsupportedBitCast(ty.clone())),
2301                    };
2302
2303                    // Avoid undefined behaviour due to overflowing signed
2304                    // integer arithmetic. Cast the operands to unsigned prior
2305                    // to performing the operation, then cast the result back
2306                    // to signed.
2307                    self.put_bitcasted_expression(
2308                        context.resolve_type(expr_handle),
2309                        expr_handle,
2310                        context,
2311                        &|writer, context, is_scoped| {
2312                            writer.put_binop(
2313                                op,
2314                                left,
2315                                right,
2316                                context,
2317                                is_scoped,
2318                                &|writer, expr, context, _is_scoped| {
2319                                    writer.put_bitcasted_expression(
2320                                        &to_unsigned(context.resolve_type(expr))?,
2321                                        expr,
2322                                        context,
2323                                        &|writer, context, is_scoped| {
2324                                            writer.put_expression(expr, context, is_scoped)
2325                                        },
2326                                    )
2327                                },
2328                            )
2329                        },
2330                    )?;
2331                } else {
2332                    self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2333                }
2334            }
2335            crate::Expression::Select {
2336                condition,
2337                accept,
2338                reject,
2339            } => match *context.resolve_type(condition) {
2340                crate::TypeInner::Scalar(crate::Scalar {
2341                    kind: crate::ScalarKind::Bool,
2342                    ..
2343                }) => {
2344                    if !is_scoped {
2345                        write!(self.out, "(")?;
2346                    }
2347                    self.put_expression(condition, context, false)?;
2348                    write!(self.out, " ? ")?;
2349                    self.put_expression(accept, context, false)?;
2350                    write!(self.out, " : ")?;
2351                    self.put_expression(reject, context, false)?;
2352                    if !is_scoped {
2353                        write!(self.out, ")")?;
2354                    }
2355                }
2356                crate::TypeInner::Vector {
2357                    scalar:
2358                        crate::Scalar {
2359                            kind: crate::ScalarKind::Bool,
2360                            ..
2361                        },
2362                    ..
2363                } => {
2364                    write!(self.out, "{NAMESPACE}::select(")?;
2365                    self.put_expression(reject, context, true)?;
2366                    write!(self.out, ", ")?;
2367                    self.put_expression(accept, context, true)?;
2368                    write!(self.out, ", ")?;
2369                    self.put_expression(condition, context, true)?;
2370                    write!(self.out, ")")?;
2371                }
2372                ref ty => {
2373                    return Err(Error::GenericValidation(format!(
2374                        "Expected select condition to be a non-bool type, got {ty:?}",
2375                    )))
2376                }
2377            },
2378            crate::Expression::Derivative { axis, expr, .. } => {
2379                use crate::DerivativeAxis as Axis;
2380                let op = match axis {
2381                    Axis::X => "dfdx",
2382                    Axis::Y => "dfdy",
2383                    Axis::Width => "fwidth",
2384                };
2385                write!(self.out, "{NAMESPACE}::{op}")?;
2386                self.put_call_parameters(iter::once(expr), context)?;
2387            }
2388            crate::Expression::Relational { fun, argument } => {
2389                let op = match fun {
2390                    crate::RelationalFunction::Any => "any",
2391                    crate::RelationalFunction::All => "all",
2392                    crate::RelationalFunction::IsNan => "isnan",
2393                    crate::RelationalFunction::IsInf => "isinf",
2394                };
2395                write!(self.out, "{NAMESPACE}::{op}")?;
2396                self.put_call_parameters(iter::once(argument), context)?;
2397            }
2398            crate::Expression::Math {
2399                fun,
2400                arg,
2401                arg1,
2402                arg2,
2403                arg3,
2404            } => {
2405                use crate::MathFunction as Mf;
2406
2407                let arg_type = context.resolve_type(arg);
2408                let scalar_argument = match arg_type {
2409                    &crate::TypeInner::Scalar(_) => true,
2410                    _ => false,
2411                };
2412
2413                let fun_name = match fun {
2414                    // comparison
2415                    Mf::Abs => "abs",
2416                    Mf::Min => "min",
2417                    Mf::Max => "max",
2418                    Mf::Clamp => "clamp",
2419                    Mf::Saturate => "saturate",
2420                    // trigonometry
2421                    Mf::Cos => "cos",
2422                    Mf::Cosh => "cosh",
2423                    Mf::Sin => "sin",
2424                    Mf::Sinh => "sinh",
2425                    Mf::Tan => "tan",
2426                    Mf::Tanh => "tanh",
2427                    Mf::Acos => "acos",
2428                    Mf::Asin => "asin",
2429                    Mf::Atan => "atan",
2430                    Mf::Atan2 => "atan2",
2431                    Mf::Asinh => "asinh",
2432                    Mf::Acosh => "acosh",
2433                    Mf::Atanh => "atanh",
2434                    Mf::Radians => "",
2435                    Mf::Degrees => "",
2436                    // decomposition
2437                    Mf::Ceil => "ceil",
2438                    Mf::Floor => "floor",
2439                    Mf::Round => "rint",
2440                    Mf::Fract => "fract",
2441                    Mf::Trunc => "trunc",
2442                    Mf::Modf => MODF_FUNCTION,
2443                    Mf::Frexp => FREXP_FUNCTION,
2444                    Mf::Ldexp => "ldexp",
2445                    // exponent
2446                    Mf::Exp => "exp",
2447                    Mf::Exp2 => "exp2",
2448                    Mf::Log => "log",
2449                    Mf::Log2 => "log2",
2450                    Mf::Pow => "pow",
2451                    // geometry
2452                    Mf::Dot => match *context.resolve_type(arg) {
2453                        crate::TypeInner::Vector {
2454                            scalar:
2455                                crate::Scalar {
2456                                    // Resolve float values to MSL's builtin dot function.
2457                                    kind: crate::ScalarKind::Float,
2458                                    ..
2459                                },
2460                            ..
2461                        } => "dot",
2462                        crate::TypeInner::Vector {
2463                            size,
2464                            scalar:
2465                                scalar @ crate::Scalar {
2466                                    kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2467                                    ..
2468                                },
2469                        } => {
2470                            // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2471                            let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2472                            write!(self.out, "{fun_name}(")?;
2473                            self.put_expression(arg, context, true)?;
2474                            write!(self.out, ", ")?;
2475                            self.put_expression(arg1.unwrap(), context, true)?;
2476                            write!(self.out, ")")?;
2477                            return Ok(());
2478                        }
2479                        _ => unreachable!(
2480                            "Correct TypeInner for dot product should be already validated"
2481                        ),
2482                    },
2483                    fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2484                        if context.lang_version >= (2, 1) {
2485                            // Write potentially optimizable code using `packed_(u?)char4`.
2486                            // The two function arguments were already reinterpreted as packed (signed
2487                            // or unsigned) chars in `Self::put_block`.
2488                            let packed_type = match fun {
2489                                Mf::Dot4I8Packed => "packed_char4",
2490                                Mf::Dot4U8Packed => "packed_uchar4",
2491                                _ => unreachable!(),
2492                            };
2493
2494                            return self.put_dot_product(
2495                                Reinterpreted::new(packed_type, arg),
2496                                Reinterpreted::new(packed_type, arg1.unwrap()),
2497                                4,
2498                                |writer, arg, index| {
2499                                    // MSL implicitly promotes these (signed or unsigned) chars to
2500                                    // `int` or `uint` in the multiplication, so no overflow can occur.
2501                                    write!(writer.out, "{arg}[{index}]")?;
2502                                    Ok(())
2503                                },
2504                            );
2505                        } else {
2506                            // Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2507                            // bitcasting from uint to `packed_char4` or `packed_uchar4`.
2508                            // See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2509                            let conversion = match fun {
2510                                Mf::Dot4I8Packed => "int",
2511                                Mf::Dot4U8Packed => "",
2512                                _ => unreachable!(),
2513                            };
2514
2515                            return self.put_dot_product(
2516                                arg,
2517                                arg1.unwrap(),
2518                                4,
2519                                |writer, arg, index| {
2520                                    write!(writer.out, "({conversion}(")?;
2521                                    writer.put_expression(arg, context, true)?;
2522                                    if index == 3 {
2523                                        write!(writer.out, ") >> 24)")?;
2524                                    } else {
2525                                        write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2526                                    }
2527                                    Ok(())
2528                                },
2529                            );
2530                        }
2531                    }
2532                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2533                    Mf::Cross => "cross",
2534                    Mf::Distance => "distance",
2535                    Mf::Length if scalar_argument => "abs",
2536                    Mf::Length => "length",
2537                    Mf::Normalize => "normalize",
2538                    Mf::FaceForward => "faceforward",
2539                    Mf::Reflect => "reflect",
2540                    Mf::Refract => "refract",
2541                    // computational
2542                    Mf::Sign => match arg_type.scalar_kind() {
2543                        Some(crate::ScalarKind::Sint) => {
2544                            return self.put_isign(arg, context);
2545                        }
2546                        _ => "sign",
2547                    },
2548                    Mf::Fma => "fma",
2549                    Mf::Mix => "mix",
2550                    Mf::Step => "step",
2551                    Mf::SmoothStep => "smoothstep",
2552                    Mf::Sqrt => "sqrt",
2553                    Mf::InverseSqrt => "rsqrt",
2554                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2555                    Mf::Transpose => "transpose",
2556                    Mf::Determinant => "determinant",
2557                    Mf::QuantizeToF16 => "",
2558                    // bits
2559                    Mf::CountTrailingZeros => "ctz",
2560                    Mf::CountLeadingZeros => "clz",
2561                    Mf::CountOneBits => "popcount",
2562                    Mf::ReverseBits => "reverse_bits",
2563                    Mf::ExtractBits => "",
2564                    Mf::InsertBits => "",
2565                    Mf::FirstTrailingBit => "",
2566                    Mf::FirstLeadingBit => "",
2567                    // data packing
2568                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2569                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2570                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2571                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2572                    Mf::Pack2x16float => "",
2573                    Mf::Pack4xI8 => "",
2574                    Mf::Pack4xU8 => "",
2575                    Mf::Pack4xI8Clamp => "",
2576                    Mf::Pack4xU8Clamp => "",
2577                    // data unpacking
2578                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2579                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2580                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2581                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2582                    Mf::Unpack2x16float => "",
2583                    Mf::Unpack4xI8 => "",
2584                    Mf::Unpack4xU8 => "",
2585                };
2586
2587                match fun {
2588                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2589                        // reverse_bits is listed as requiring MSL 2.1 but that
2590                        // is a copy/paste error. Looking at previous snapshots
2591                        // on web.archive.org it's present in MSL 1.2.
2592                        //
2593                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
2594                        // also talks about MSL 1.2 adding "New integer
2595                        // functions to extract, insert, and reverse bits, as
2596                        // described in Integer Functions."
2597                        if context.lang_version < (1, 2) {
2598                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
2599                        }
2600                    }
2601                    _ => {}
2602                }
2603
2604                match fun {
2605                    Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2606                        write!(self.out, "{ABS_FUNCTION}(")?;
2607                        self.put_expression(arg, context, true)?;
2608                        write!(self.out, ")")?;
2609                    }
2610                    Mf::Distance if scalar_argument => {
2611                        write!(self.out, "{NAMESPACE}::abs(")?;
2612                        self.put_expression(arg, context, false)?;
2613                        write!(self.out, " - ")?;
2614                        self.put_expression(arg1.unwrap(), context, false)?;
2615                        write!(self.out, ")")?;
2616                    }
2617                    Mf::FirstTrailingBit => {
2618                        let scalar = context.resolve_type(arg).scalar().unwrap();
2619                        let constant = scalar.width * 8 + 1;
2620
2621                        write!(self.out, "((({NAMESPACE}::ctz(")?;
2622                        self.put_expression(arg, context, true)?;
2623                        write!(self.out, ") + 1) % {constant}) - 1)")?;
2624                    }
2625                    Mf::FirstLeadingBit => {
2626                        let inner = context.resolve_type(arg);
2627                        let scalar = inner.scalar().unwrap();
2628                        let constant = scalar.width * 8 - 1;
2629
2630                        write!(
2631                            self.out,
2632                            "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2633                        )?;
2634
2635                        if scalar.kind == crate::ScalarKind::Sint {
2636                            write!(self.out, "{NAMESPACE}::select(")?;
2637                            self.put_expression(arg, context, true)?;
2638                            write!(self.out, ", ~")?;
2639                            self.put_expression(arg, context, true)?;
2640                            write!(self.out, ", ")?;
2641                            self.put_expression(arg, context, true)?;
2642                            write!(self.out, " < 0)")?;
2643                        } else {
2644                            self.put_expression(arg, context, true)?;
2645                        }
2646
2647                        write!(self.out, "), ")?;
2648
2649                        // or metal will complain that select is ambiguous
2650                        match *inner {
2651                            crate::TypeInner::Vector { size, scalar } => {
2652                                let size = common::vector_size_str(size);
2653                                let name = scalar.to_msl_name();
2654                                write!(self.out, "{name}{size}")?;
2655                            }
2656                            crate::TypeInner::Scalar(scalar) => {
2657                                let name = scalar.to_msl_name();
2658                                write!(self.out, "{name}")?;
2659                            }
2660                            _ => (),
2661                        }
2662
2663                        write!(self.out, "(-1), ")?;
2664                        self.put_expression(arg, context, true)?;
2665                        write!(self.out, " == 0")?;
2666                        if scalar.kind == crate::ScalarKind::Sint {
2667                            write!(self.out, " || ")?;
2668                            self.put_expression(arg, context, true)?;
2669                            write!(self.out, " == -1")?;
2670                        }
2671                        write!(self.out, ")")?;
2672                    }
2673                    Mf::Unpack2x16float => {
2674                        write!(self.out, "float2(as_type<half2>(")?;
2675                        self.put_expression(arg, context, false)?;
2676                        write!(self.out, "))")?;
2677                    }
2678                    Mf::Pack2x16float => {
2679                        write!(self.out, "as_type<uint>(half2(")?;
2680                        self.put_expression(arg, context, false)?;
2681                        write!(self.out, "))")?;
2682                    }
2683                    Mf::ExtractBits => {
2684                        // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
2685                        // to first sanitize the offset and count first. If we don't do this, Apple chips
2686                        // will return out-of-spec values if the extracted range is not within the bit width.
2687                        //
2688                        // This encodes the exact formula specified by the wgsl spec, without temporary values:
2689                        // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
2690                        //
2691                        // w = sizeof(x) * 8
2692                        // o = min(offset, w)
2693                        // tmp = w - o
2694                        // c = min(count, tmp)
2695                        //
2696                        // bitfieldExtract(x, o, c)
2697                        //
2698                        // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
2699
2700                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2701
2702                        write!(self.out, "{NAMESPACE}::extract_bits(")?;
2703                        self.put_expression(arg, context, true)?;
2704                        write!(self.out, ", {NAMESPACE}::min(")?;
2705                        self.put_expression(arg1.unwrap(), context, true)?;
2706                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2707                        self.put_expression(arg2.unwrap(), context, true)?;
2708                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2709                        self.put_expression(arg1.unwrap(), context, true)?;
2710                        write!(self.out, ", {scalar_bits}u)))")?;
2711                    }
2712                    Mf::InsertBits => {
2713                        // The behavior of InsertBits has the same issue as ExtractBits.
2714                        //
2715                        // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
2716
2717                        let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2718
2719                        write!(self.out, "{NAMESPACE}::insert_bits(")?;
2720                        self.put_expression(arg, context, true)?;
2721                        write!(self.out, ", ")?;
2722                        self.put_expression(arg1.unwrap(), context, true)?;
2723                        write!(self.out, ", {NAMESPACE}::min(")?;
2724                        self.put_expression(arg2.unwrap(), context, true)?;
2725                        write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2726                        self.put_expression(arg3.unwrap(), context, true)?;
2727                        write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2728                        self.put_expression(arg2.unwrap(), context, true)?;
2729                        write!(self.out, ", {scalar_bits}u)))")?;
2730                    }
2731                    Mf::Radians => {
2732                        write!(self.out, "((")?;
2733                        self.put_expression(arg, context, false)?;
2734                        write!(self.out, ") * 0.017453292519943295474)")?;
2735                    }
2736                    Mf::Degrees => {
2737                        write!(self.out, "((")?;
2738                        self.put_expression(arg, context, false)?;
2739                        write!(self.out, ") * 57.295779513082322865)")?;
2740                    }
2741                    Mf::Modf | Mf::Frexp => {
2742                        write!(self.out, "{fun_name}")?;
2743                        self.put_call_parameters(iter::once(arg), context)?;
2744                    }
2745                    Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2746                    Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2747                    Mf::Pack4xI8Clamp => {
2748                        self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2749                    }
2750                    Mf::Pack4xU8Clamp => {
2751                        self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2752                    }
2753                    fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2754                        let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2755                            "u"
2756                        } else {
2757                            ""
2758                        };
2759
2760                        if context.lang_version >= (2, 1) {
2761                            // Metal uses little endian byte order, which matches what WGSL expects here.
2762                            write!(
2763                                self.out,
2764                                "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2765                            )?;
2766                            self.put_expression(arg, context, true)?;
2767                            write!(self.out, "))")?;
2768                        } else {
2769                            // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2770                            write!(self.out, "({sign_prefix}int4(")?;
2771                            self.put_expression(arg, context, true)?;
2772                            write!(self.out, ", ")?;
2773                            self.put_expression(arg, context, true)?;
2774                            write!(self.out, " >> 8, ")?;
2775                            self.put_expression(arg, context, true)?;
2776                            write!(self.out, " >> 16, ")?;
2777                            self.put_expression(arg, context, true)?;
2778                            write!(self.out, " >> 24) << 24 >> 24)")?;
2779                        }
2780                    }
2781                    Mf::QuantizeToF16 => {
2782                        match *context.resolve_type(arg) {
2783                            crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2784                            crate::TypeInner::Vector { size, .. } => write!(
2785                                self.out,
2786                                "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2787                                size = common::vector_size_str(size),
2788                            )?,
2789                            _ => unreachable!(
2790                                "Correct TypeInner for QuantizeToF16 should be already validated"
2791                            ),
2792                        };
2793
2794                        self.put_expression(arg, context, true)?;
2795                        write!(self.out, "))")?;
2796                    }
2797                    _ => {
2798                        write!(self.out, "{NAMESPACE}::{fun_name}")?;
2799                        self.put_call_parameters(
2800                            iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2801                            context,
2802                        )?;
2803                    }
2804                }
2805            }
2806            crate::Expression::As {
2807                expr,
2808                kind,
2809                convert,
2810            } => match *context.resolve_type(expr) {
2811                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2812                    if src.kind == crate::ScalarKind::Float
2813                        && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2814                        && convert.is_some()
2815                    {
2816                        // Use helper functions for float to int casts in order to avoid
2817                        // undefined behaviour when value is out of range for the target
2818                        // type.
2819                        let fun_name = match (kind, convert) {
2820                            (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2821                            (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2822                            (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2823                            (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2824                            _ => unreachable!(),
2825                        };
2826                        write!(self.out, "{fun_name}(")?;
2827                        self.put_expression(expr, context, true)?;
2828                        write!(self.out, ")")?;
2829                    } else {
2830                        let target_scalar = crate::Scalar {
2831                            kind,
2832                            width: convert.unwrap_or(src.width),
2833                        };
2834                        let op = match convert {
2835                            Some(_) => "static_cast",
2836                            None => "as_type",
2837                        };
2838                        write!(self.out, "{op}<")?;
2839                        match *context.resolve_type(expr) {
2840                            crate::TypeInner::Vector { size, .. } => {
2841                                put_numeric_type(&mut self.out, target_scalar, &[size])?
2842                            }
2843                            _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2844                        };
2845                        write!(self.out, ">(")?;
2846                        self.put_expression(expr, context, true)?;
2847                        write!(self.out, ")")?;
2848                    }
2849                }
2850                crate::TypeInner::Matrix {
2851                    columns,
2852                    rows,
2853                    scalar,
2854                } => {
2855                    let target_scalar = crate::Scalar {
2856                        kind,
2857                        width: convert.unwrap_or(scalar.width),
2858                    };
2859                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2860                    write!(self.out, "(")?;
2861                    self.put_expression(expr, context, true)?;
2862                    write!(self.out, ")")?;
2863                }
2864                ref ty => {
2865                    return Err(Error::GenericValidation(format!(
2866                        "Unsupported type for As: {ty:?}"
2867                    )))
2868                }
2869            },
2870            // has to be a named expression
2871            crate::Expression::CallResult(_)
2872            | crate::Expression::AtomicResult { .. }
2873            | crate::Expression::WorkGroupUniformLoadResult { .. }
2874            | crate::Expression::SubgroupBallotResult
2875            | crate::Expression::SubgroupOperationResult { .. }
2876            | crate::Expression::RayQueryProceedResult => {
2877                unreachable!()
2878            }
2879            crate::Expression::ArrayLength(expr) => {
2880                // Find the global to which the array belongs.
2881                let global = match context.function.expressions[expr] {
2882                    crate::Expression::AccessIndex { base, .. } => {
2883                        match context.function.expressions[base] {
2884                            crate::Expression::GlobalVariable(handle) => handle,
2885                            ref ex => {
2886                                return Err(Error::GenericValidation(format!(
2887                                    "Expected global variable in AccessIndex, got {ex:?}"
2888                                )))
2889                            }
2890                        }
2891                    }
2892                    crate::Expression::GlobalVariable(handle) => handle,
2893                    ref ex => {
2894                        return Err(Error::GenericValidation(format!(
2895                            "Unexpected expression in ArrayLength, got {ex:?}"
2896                        )))
2897                    }
2898                };
2899
2900                if !is_scoped {
2901                    write!(self.out, "(")?;
2902                }
2903                write!(self.out, "1 + ")?;
2904                self.put_dynamic_array_max_index(global, context)?;
2905                if !is_scoped {
2906                    write!(self.out, ")")?;
2907                }
2908            }
2909            crate::Expression::RayQueryVertexPositions { .. } => {
2910                unimplemented!()
2911            }
2912            crate::Expression::RayQueryGetIntersection { query, committed } => {
2913                if context.lang_version < (2, 4) {
2914                    return Err(Error::UnsupportedRayTracing);
2915                }
2916
2917                write!(
2918                    self.out,
2919                    "{}_{committed}(",
2920                    super::ray::INTERSECTION_FUNCTION_NAME
2921                )?;
2922                self.put_expression(query, context, true)?;
2923                write!(self.out, ")")?;
2924            }
2925            crate::Expression::CooperativeLoad { ref data, .. } => {
2926                if context.lang_version < (2, 3) {
2927                    return Err(Error::UnsupportedCooperativeMatrix);
2928                }
2929                write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2930                write!(self.out, "&")?;
2931                self.put_access_chain(data.pointer, context.policies.index, context)?;
2932                write!(self.out, ", ")?;
2933                self.put_expression(data.stride, context, true)?;
2934                // Metal's `simdgroup_load` treats its `transpose` flag as
2935                // "memory is transposed from the simdgroup_matrix's canonical
2936                // layout". On Apple GPUs that canonical layout is row-major,
2937                // so `transpose=false` loads from row-major memory. WGSL's
2938                // `coopLoadT` (row_major=true) = row-major memory, so it must
2939                // map to `transpose=false`. Hence the negation.
2940                write!(self.out, ", {})", !data.row_major)?;
2941            }
2942            crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2943                if context.lang_version < (2, 3) {
2944                    return Err(Error::UnsupportedCooperativeMatrix);
2945                }
2946                write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2947                self.put_expression(a, context, true)?;
2948                write!(self.out, ", ")?;
2949                self.put_expression(b, context, true)?;
2950                write!(self.out, ", ")?;
2951                self.put_expression(c, context, true)?;
2952                write!(self.out, ")")?;
2953            }
2954        }
2955        Ok(())
2956    }
2957
2958    /// Emits code for a binary operation, using the provided callback to emit
2959    /// the left and right operands.
2960    fn put_binop<F>(
2961        &mut self,
2962        op: crate::BinaryOperator,
2963        left: Handle<crate::Expression>,
2964        right: Handle<crate::Expression>,
2965        context: &ExpressionContext,
2966        is_scoped: bool,
2967        put_expression: &F,
2968    ) -> BackendResult
2969    where
2970        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2971    {
2972        let op_str = back::binary_operation_str(op);
2973
2974        if !is_scoped {
2975            write!(self.out, "(")?;
2976        }
2977
2978        // Cast packed vector if necessary
2979        // Packed vector - matrix multiplications are not supported in MSL
2980        if op == crate::BinaryOperator::Multiply
2981            && matches!(
2982                context.resolve_type(right),
2983                &crate::TypeInner::Matrix { .. }
2984            )
2985        {
2986            self.put_wrapped_expression_for_packed_vec3_access(
2987                left,
2988                context,
2989                false,
2990                put_expression,
2991            )?;
2992        } else {
2993            put_expression(self, left, context, false)?;
2994        }
2995
2996        write!(self.out, " {op_str} ")?;
2997
2998        // See comment above
2999        if op == crate::BinaryOperator::Multiply
3000            && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3001        {
3002            self.put_wrapped_expression_for_packed_vec3_access(
3003                right,
3004                context,
3005                false,
3006                put_expression,
3007            )?;
3008        } else {
3009            put_expression(self, right, context, false)?;
3010        }
3011
3012        if !is_scoped {
3013            write!(self.out, ")")?;
3014        }
3015
3016        Ok(())
3017    }
3018
3019    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
3020    fn put_wrapped_expression_for_packed_vec3_access<F>(
3021        &mut self,
3022        expr_handle: Handle<crate::Expression>,
3023        context: &ExpressionContext,
3024        is_scoped: bool,
3025        put_expression: &F,
3026    ) -> BackendResult
3027    where
3028        F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3029    {
3030        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3031            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3032            put_expression(self, expr_handle, context, is_scoped)?;
3033            write!(self.out, ")")?;
3034        } else {
3035            put_expression(self, expr_handle, context, is_scoped)?;
3036        }
3037        Ok(())
3038    }
3039
3040    /// Emits code for an expression using the provided callback, wrapping the
3041    /// result in a bitcast to the type `cast_to`.
3042    fn put_bitcasted_expression<F>(
3043        &mut self,
3044        cast_to: &crate::TypeInner,
3045        inner_expr: Handle<crate::Expression>,
3046        context: &ExpressionContext,
3047        put_expression: &F,
3048    ) -> BackendResult
3049    where
3050        F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3051    {
3052        // For sub-32-bit types, C++ integer promotion can widen the inner
3053        // expression (e.g. `ushort + ushort` promotes to `int`), making a
3054        // direct `as_type<short>(int_expr)` invalid due to size mismatch.
3055        // We wrap with `static_cast` to truncate back before the bitcast.
3056        let needs_truncation = match *cast_to {
3057            crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3058            crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3059            _ => false,
3060        };
3061
3062        write!(self.out, "as_type<")?;
3063        match *cast_to {
3064            crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3065            crate::TypeInner::Vector { size, scalar } => {
3066                put_numeric_type(&mut self.out, scalar, &[size])?
3067            }
3068            _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3069        };
3070        write!(self.out, ">(")?;
3071
3072        if needs_truncation {
3073            write!(self.out, "static_cast<")?;
3074            // Cast to the unsigned version of the target type to truncate
3075            let unsigned_scalar = match *cast_to {
3076                crate::TypeInner::Scalar(scalar) => crate::Scalar {
3077                    kind: crate::ScalarKind::Uint,
3078                    ..scalar
3079                },
3080                crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3081                    kind: crate::ScalarKind::Uint,
3082                    ..scalar
3083                },
3084                _ => unreachable!(),
3085            };
3086            match *cast_to {
3087                crate::TypeInner::Scalar(_) => {
3088                    put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3089                }
3090                crate::TypeInner::Vector { size, .. } => {
3091                    put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3092                }
3093                _ => unreachable!(),
3094            };
3095            write!(self.out, ">(")?;
3096        }
3097
3098        // if it's packed, we must unpack it (e.g., float3(val)) before the bitcast.
3099        if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3100            put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3101            write!(self.out, "(")?;
3102            put_expression(self, context, true)?;
3103            write!(self.out, ")")?;
3104        } else {
3105            put_expression(self, context, true)?;
3106        }
3107
3108        if needs_truncation {
3109            write!(self.out, ")")?;
3110        }
3111
3112        write!(self.out, ")")?;
3113        Ok(())
3114    }
3115
3116    /// Write a `GuardedIndex` as a Metal expression.
3117    fn put_index(
3118        &mut self,
3119        index: index::GuardedIndex,
3120        context: &ExpressionContext,
3121        is_scoped: bool,
3122    ) -> BackendResult {
3123        match index {
3124            index::GuardedIndex::Expression(expr) => {
3125                self.put_expression(expr, context, is_scoped)?
3126            }
3127            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3128        }
3129        Ok(())
3130    }
3131
3132    /// Emit an index bounds check condition for `chain`, if required.
3133    ///
3134    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
3135    /// operating either on a pointer to a value, or on a value directly. If we cannot
3136    /// statically determine that all indexing operations in `chain` are within
3137    /// bounds, then write a conditional expression to check them dynamically,
3138    /// and return true. All accesses in the chain are checked by the generated
3139    /// expression.
3140    ///
3141    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
3142    ///
3143    /// The text written is of the form:
3144    ///
3145    /// ```ignore
3146    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
3147    /// ```
3148    ///
3149    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
3150    /// statements, presumably these arguments start an indented `if` statement; for
3151    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
3152    /// expression. In either case, what is written is not a complete syntactic structure
3153    /// in its own right, and the caller will have to finish it off if we return `true`.
3154    ///
3155    /// If no expression is written, return false.
3156    ///
3157    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
3158    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3159    /// [`Store`]: crate::Statement::Store
3160    /// [`Load`]: crate::Expression::Load
3161    fn put_bounds_checks(
3162        &mut self,
3163        chain: Handle<crate::Expression>,
3164        context: &ExpressionContext,
3165        level: back::Level,
3166        prefix: &'static str,
3167    ) -> Result<bool, Error> {
3168        let mut check_written = false;
3169
3170        // Iterate over the access chain, handling each required bounds check.
3171        for item in context.bounds_check_iter(chain) {
3172            let BoundsCheck {
3173                base,
3174                index,
3175                length,
3176            } = item;
3177
3178            if check_written {
3179                write!(self.out, " && ")?;
3180            } else {
3181                write!(self.out, "{level}{prefix}")?;
3182                check_written = true;
3183            }
3184
3185            // Check that the index falls within bounds. Do this with a single
3186            // comparison, by casting the index to `uint` first, so that negative
3187            // indices become large positive values.
3188            write!(self.out, "uint(")?;
3189            self.put_index(index, context, true)?;
3190            self.out.write_str(") < ")?;
3191            match length {
3192                index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3193                index::IndexableLength::Dynamic => {
3194                    let global = context.function.originating_global(base).ok_or_else(|| {
3195                        Error::GenericValidation("Could not find originating global".into())
3196                    })?;
3197                    write!(self.out, "1 + ")?;
3198                    self.put_dynamic_array_max_index(global, context)?
3199                }
3200            }
3201        }
3202
3203        Ok(check_written)
3204    }
3205
3206    /// Write the access chain `chain`.
3207    ///
3208    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
3209    /// operating either on a pointer to a value, or on a value directly.
3210    ///
3211    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
3212    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
3213    /// that must be handled in the caller.
3214    ///
3215    /// Handle the entire chain, recursing back into `put_expression` only for index
3216    /// expressions and the base expression that originates the pointer or composite value
3217    /// being accessed. This allows `put_expression` to assume that any `Access` or
3218    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
3219    /// `ReadZeroSkipWrite` checks.
3220    ///
3221    /// [`Access`]: crate::Expression::Access
3222    /// [`AccessIndex`]: crate::Expression::AccessIndex
3223    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3224    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
3225    fn put_access_chain(
3226        &mut self,
3227        chain: Handle<crate::Expression>,
3228        policy: index::BoundsCheckPolicy,
3229        context: &ExpressionContext,
3230    ) -> BackendResult {
3231        match context.function.expressions[chain] {
3232            crate::Expression::Access { base, index } => {
3233                let mut base_ty = context.resolve_type(base);
3234
3235                // Look through any pointers to see what we're really indexing.
3236                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3237                    base_ty = &context.module.types[base].inner;
3238                }
3239
3240                self.put_subscripted_access_chain(
3241                    base,
3242                    base_ty,
3243                    index::GuardedIndex::Expression(index),
3244                    policy,
3245                    context,
3246                )?;
3247            }
3248            crate::Expression::AccessIndex { base, index } => {
3249                let base_resolution = &context.info[base].ty;
3250                let mut base_ty = base_resolution.inner_with(&context.module.types);
3251                let mut base_ty_handle = base_resolution.handle();
3252
3253                // Look through any pointers to see what we're really indexing.
3254                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3255                    base_ty = &context.module.types[base].inner;
3256                    base_ty_handle = Some(base);
3257                }
3258
3259                // Handle structs and anything else that can use `.x` syntax here, so
3260                // `put_subscripted_access_chain` won't have to handle the absurd case of
3261                // indexing a struct with an expression.
3262                match *base_ty {
3263                    crate::TypeInner::Struct { .. } => {
3264                        let base_ty = base_ty_handle.unwrap();
3265                        self.put_access_chain(base, policy, context)?;
3266                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
3267                        write!(self.out, ".{name}")?;
3268                    }
3269                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3270                        self.put_access_chain(base, policy, context)?;
3271                        // Prior to Metal v2.1 component access for packed vectors wasn't available
3272                        // however array indexing is
3273                        if context.get_packed_vec_kind(base).is_some() {
3274                            write!(self.out, "[{index}]")?;
3275                        } else {
3276                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3277                        }
3278                    }
3279                    _ => {
3280                        self.put_subscripted_access_chain(
3281                            base,
3282                            base_ty,
3283                            index::GuardedIndex::Known(index),
3284                            policy,
3285                            context,
3286                        )?;
3287                    }
3288                }
3289            }
3290            _ => self.put_expression(chain, context, false)?,
3291        }
3292
3293        Ok(())
3294    }
3295
3296    /// Write a `[]`-style access of `base` by `index`.
3297    ///
3298    /// If `policy` is [`Restrict`], then generate code as needed to force all index
3299    /// values within bounds.
3300    ///
3301    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
3302    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
3303    /// removed. Our callers often already have this handy.
3304    ///
3305    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
3306    /// referencing vector components by name.
3307    ///
3308    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
3309    /// [`Array`]: crate::TypeInner::Array
3310    /// [`Vector`]: crate::TypeInner::Vector
3311    /// [`Pointer`]: crate::TypeInner::Pointer
3312    fn put_subscripted_access_chain(
3313        &mut self,
3314        base: Handle<crate::Expression>,
3315        base_ty: &crate::TypeInner,
3316        index: index::GuardedIndex,
3317        policy: index::BoundsCheckPolicy,
3318        context: &ExpressionContext,
3319    ) -> BackendResult {
3320        let accessing_wrapped_array = match *base_ty {
3321            crate::TypeInner::Array {
3322                size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3323                ..
3324            } => true,
3325            _ => false,
3326        };
3327        let accessing_wrapped_binding_array =
3328            matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3329
3330        self.put_access_chain(base, policy, context)?;
3331        if accessing_wrapped_array {
3332            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3333        }
3334        write!(self.out, "[")?;
3335
3336        // Decide whether this index needs to be clamped to fall within range.
3337        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3338            context.access_needs_check(base, index)
3339        } else {
3340            None
3341        };
3342        if let Some(limit) = restriction_needed {
3343            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3344            self.put_index(index, context, true)?;
3345            write!(self.out, "), ")?;
3346            match limit {
3347                index::IndexableLength::Known(limit) => {
3348                    write!(self.out, "{}u", limit - 1)?;
3349                }
3350                index::IndexableLength::Dynamic => {
3351                    let global = context.function.originating_global(base).ok_or_else(|| {
3352                        Error::GenericValidation("Could not find originating global".into())
3353                    })?;
3354                    self.put_dynamic_array_max_index(global, context)?;
3355                }
3356            }
3357            write!(self.out, ")")?;
3358        } else {
3359            self.put_index(index, context, true)?;
3360        }
3361
3362        write!(self.out, "]")?;
3363
3364        if accessing_wrapped_binding_array {
3365            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3366        }
3367
3368        Ok(())
3369    }
3370
3371    fn put_load(
3372        &mut self,
3373        pointer: Handle<crate::Expression>,
3374        context: &ExpressionContext,
3375        is_scoped: bool,
3376    ) -> BackendResult {
3377        // Since access chains never cross between address spaces, we can just
3378        // check the index bounds check policy once at the top.
3379        let policy = context.choose_bounds_check_policy(pointer);
3380        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3381            && self.put_bounds_checks(
3382                pointer,
3383                context,
3384                back::Level(0),
3385                if is_scoped { "" } else { "(" },
3386            )?
3387        {
3388            write!(self.out, " ? ")?;
3389            self.put_unchecked_load(pointer, policy, context)?;
3390            write!(self.out, " : DefaultConstructible()")?;
3391
3392            if !is_scoped {
3393                write!(self.out, ")")?;
3394            }
3395        } else {
3396            self.put_unchecked_load(pointer, policy, context)?;
3397        }
3398
3399        Ok(())
3400    }
3401
3402    fn put_unchecked_load(
3403        &mut self,
3404        pointer: Handle<crate::Expression>,
3405        policy: index::BoundsCheckPolicy,
3406        context: &ExpressionContext,
3407    ) -> BackendResult {
3408        let is_atomic_pointer = context
3409            .resolve_type(pointer)
3410            .is_atomic_pointer(&context.module.types);
3411
3412        if is_atomic_pointer {
3413            write!(
3414                self.out,
3415                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3416            )?;
3417            self.put_access_chain(pointer, policy, context)?;
3418            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3419        } else {
3420            // We don't do any dereferencing with `*` here as pointer arguments to functions
3421            // are done by `&` references and not `*` pointers. These do not need to be
3422            // dereferenced.
3423            self.put_access_chain(pointer, policy, context)?;
3424        }
3425
3426        Ok(())
3427    }
3428
3429    fn put_return_value(
3430        &mut self,
3431        level: back::Level,
3432        expr_handle: Handle<crate::Expression>,
3433        result_struct: Option<&str>,
3434        context: &ExpressionContext,
3435    ) -> BackendResult {
3436        match result_struct {
3437            Some(struct_name) => {
3438                let mut has_point_size = false;
3439                let result_ty = context.function.result.as_ref().unwrap().ty;
3440                match context.module.types[result_ty].inner {
3441                    crate::TypeInner::Struct { ref members, .. } => {
3442                        let tmp = "_tmp";
3443                        write!(self.out, "{level}const auto {tmp} = ")?;
3444                        self.put_expression(expr_handle, context, true)?;
3445                        writeln!(self.out, ";")?;
3446                        write!(self.out, "{level}return {struct_name} {{")?;
3447
3448                        let mut is_first = true;
3449
3450                        for (index, member) in members.iter().enumerate() {
3451                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3452                                member.binding
3453                            {
3454                                has_point_size = true;
3455                                if !context.pipeline_options.allow_and_force_point_size {
3456                                    continue;
3457                                }
3458                            }
3459
3460                            let comma = if is_first { "" } else { "," };
3461                            is_first = false;
3462                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3463                            // HACK: we are forcefully deduplicating the expression here
3464                            // to convert from a wrapped struct to a raw array, e.g.
3465                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
3466                            if let crate::TypeInner::Array {
3467                                size: crate::ArraySize::Constant(size),
3468                                ..
3469                            } = context.module.types[member.ty].inner
3470                            {
3471                                write!(self.out, "{comma} {{")?;
3472                                for j in 0..size.get() {
3473                                    if j != 0 {
3474                                        write!(self.out, ",")?;
3475                                    }
3476                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3477                                }
3478                                write!(self.out, "}}")?;
3479                            } else {
3480                                write!(self.out, "{comma} {tmp}.{name}")?;
3481                            }
3482                        }
3483                    }
3484                    _ => {
3485                        write!(self.out, "{level}return {struct_name} {{ ")?;
3486                        self.put_expression(expr_handle, context, true)?;
3487                    }
3488                }
3489
3490                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3491                    let stage = context.module.entry_points[ep_index as usize].stage;
3492                    if context.pipeline_options.allow_and_force_point_size
3493                        && stage == crate::ShaderStage::Vertex
3494                        && !has_point_size
3495                    {
3496                        // point size was injected and comes last
3497                        write!(self.out, ", 1.0")?;
3498                    }
3499                }
3500                write!(self.out, " }}")?;
3501            }
3502            None => {
3503                write!(self.out, "{level}return ")?;
3504                self.put_expression(expr_handle, context, true)?;
3505            }
3506        }
3507        writeln!(self.out, ";")?;
3508        Ok(())
3509    }
3510
3511    /// Helper method used to find which expressions of a given function require baking
3512    ///
3513    /// # Notes
3514    /// This function overwrites the contents of `self.need_bake_expressions`
3515    fn update_expressions_to_bake(
3516        &mut self,
3517        func: &crate::Function,
3518        info: &valid::FunctionInfo,
3519        context: &ExpressionContext,
3520    ) {
3521        use crate::Expression;
3522        self.need_bake_expressions.clear();
3523
3524        for (expr_handle, expr) in func.expressions.iter() {
3525            // Expressions whose reference count is above the
3526            // threshold should always be stored in temporaries.
3527            let expr_info = &info[expr_handle];
3528            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3529            if min_ref_count <= expr_info.ref_count {
3530                self.need_bake_expressions.insert(expr_handle);
3531            } else {
3532                match expr_info.ty {
3533                    // force ray desc to be baked: it's used multiple times internally
3534                    TypeResolution::Handle(h)
3535                        if Some(h) == context.module.special_types.ray_desc =>
3536                    {
3537                        self.need_bake_expressions.insert(expr_handle);
3538                    }
3539                    _ => {}
3540                }
3541            }
3542
3543            if let Expression::Math {
3544                fun,
3545                arg,
3546                arg1,
3547                arg2,
3548                ..
3549            } = *expr
3550            {
3551                match fun {
3552                    // WGSL's `dot` function works on any `vecN` type, but Metal's only
3553                    // works on floating-point vectors, so we emit inline code for
3554                    // integer vector `dot` calls. But that code uses each argument `N`
3555                    // times, once for each component (see `put_dot_product`), so to
3556                    // avoid duplicated evaluation, we must bake integer operands.
3557                    // This applies both when using the polyfill (because of the duplicate
3558                    // evaluation issue) and when we don't use the polyfill (because we
3559                    // need them to be emitted before casting to packed chars -- see the
3560                    // comment at the call to `put_casting_to_packed_chars`).
3561                    crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3562                        self.need_bake_expressions.insert(arg);
3563                        self.need_bake_expressions.insert(arg1.unwrap());
3564                    }
3565                    crate::MathFunction::FirstLeadingBit => {
3566                        self.need_bake_expressions.insert(arg);
3567                    }
3568                    crate::MathFunction::Pack4xI8
3569                    | crate::MathFunction::Pack4xU8
3570                    | crate::MathFunction::Pack4xI8Clamp
3571                    | crate::MathFunction::Pack4xU8Clamp
3572                    | crate::MathFunction::Unpack4xI8
3573                    | crate::MathFunction::Unpack4xU8 => {
3574                        // On MSL < 2.1, we emit a polyfill for these functions that uses the
3575                        // argument multiple times. This is no longer necessary on MSL >= 2.1.
3576                        if context.lang_version < (2, 1) {
3577                            self.need_bake_expressions.insert(arg);
3578                        }
3579                    }
3580                    crate::MathFunction::ExtractBits => {
3581                        // Only argument 1 is re-used.
3582                        self.need_bake_expressions.insert(arg1.unwrap());
3583                    }
3584                    crate::MathFunction::InsertBits => {
3585                        // Only argument 2 is re-used.
3586                        self.need_bake_expressions.insert(arg2.unwrap());
3587                    }
3588                    crate::MathFunction::Sign => {
3589                        // WGSL's `sign` function works also on signed ints, but Metal's only
3590                        // works on floating points, so we emit inline code for integer `sign`
3591                        // calls. But that code uses each argument 2 times (see `put_isign`),
3592                        // so to avoid duplicated evaluation, we must bake the argument.
3593                        let inner = context.resolve_type(expr_handle);
3594                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3595                            self.need_bake_expressions.insert(arg);
3596                        }
3597                    }
3598                    _ => {}
3599                }
3600            }
3601        }
3602    }
3603
3604    pub(super) fn start_baking_expression(
3605        &mut self,
3606        handle: Handle<crate::Expression>,
3607        context: &ExpressionContext,
3608        name: &str,
3609    ) -> BackendResult {
3610        match context.info[handle].ty {
3611            TypeResolution::Handle(ty_handle) => {
3612                let ty_name = TypeContext {
3613                    handle: ty_handle,
3614                    gctx: context.module.to_ctx(),
3615                    names: &self.names,
3616                    access: crate::StorageAccess::empty(),
3617                    first_time: false,
3618                };
3619                write!(self.out, "{ty_name}")?;
3620            }
3621            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3622                put_numeric_type(&mut self.out, scalar, &[])?;
3623            }
3624            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3625                put_numeric_type(&mut self.out, scalar, &[size])?;
3626            }
3627            TypeResolution::Value(crate::TypeInner::Matrix {
3628                columns,
3629                rows,
3630                scalar,
3631            }) => {
3632                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3633            }
3634            TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3635                columns,
3636                rows,
3637                scalar,
3638                role: _,
3639            }) => {
3640                write!(
3641                    self.out,
3642                    "{}::simdgroup_{}{}x{}",
3643                    NAMESPACE,
3644                    scalar.to_msl_name(),
3645                    columns as u32,
3646                    rows as u32,
3647                )?;
3648            }
3649            TypeResolution::Value(ref other) => {
3650                log::warn!("Type {other:?} isn't a known local");
3651                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3652            }
3653        }
3654
3655        //TODO: figure out the naming scheme that wouldn't collide with user names.
3656        write!(self.out, " {name} = ")?;
3657
3658        Ok(())
3659    }
3660
3661    /// Cache a clamped level of detail value, if necessary.
3662    ///
3663    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
3664    /// properly clamped level of detail value both in the access itself, and
3665    /// for fetching the size of the requested MIP level, needed to clamp the
3666    /// coordinates. To avoid recomputing this clamped level of detail, we cache
3667    /// it in a temporary variable, as part of the [`Emit`] statement covering
3668    /// the [`ImageLoad`] expression.
3669    ///
3670    /// [`ImageLoad`]: crate::Expression::ImageLoad
3671    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
3672    /// [`Emit`]: crate::Statement::Emit
3673    fn put_cache_restricted_level(
3674        &mut self,
3675        load: Handle<crate::Expression>,
3676        image: Handle<crate::Expression>,
3677        mip_level: Option<Handle<crate::Expression>>,
3678        indent: back::Level,
3679        context: &StatementContext,
3680    ) -> BackendResult {
3681        // Does this image access actually require (or even permit) a
3682        // level-of-detail, and does the policy require us to restrict it?
3683        let level_of_detail = match mip_level {
3684            Some(level) => level,
3685            None => return Ok(()),
3686        };
3687
3688        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3689            || !context.expression.image_needs_lod(image)
3690        {
3691            return Ok(());
3692        }
3693
3694        write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3695        self.put_restricted_scalar_image_index(
3696            image,
3697            level_of_detail,
3698            "get_num_mip_levels",
3699            &context.expression,
3700        )?;
3701        writeln!(self.out, ";")?;
3702
3703        Ok(())
3704    }
3705
3706    /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
3707    ///
3708    /// Caches the results in temporary variables (whose names are derived from
3709    /// the original variable names). This caching avoids the need to redo the
3710    /// casting for each vector component when emitting the dot product.
3711    fn put_casting_to_packed_chars(
3712        &mut self,
3713        fun: crate::MathFunction,
3714        arg0: Handle<crate::Expression>,
3715        arg1: Handle<crate::Expression>,
3716        indent: back::Level,
3717        context: &StatementContext<'_>,
3718    ) -> Result<(), Error> {
3719        let packed_type = match fun {
3720            crate::MathFunction::Dot4I8Packed => "packed_char4",
3721            crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3722            _ => unreachable!(),
3723        };
3724
3725        for arg in [arg0, arg1] {
3726            write!(
3727                self.out,
3728                "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3729                Reinterpreted::new(packed_type, arg)
3730            )?;
3731            self.put_expression(arg, &context.expression, true)?;
3732            writeln!(self.out, ");")?;
3733        }
3734
3735        Ok(())
3736    }
3737
3738    fn put_block(
3739        &mut self,
3740        level: back::Level,
3741        statements: &[crate::Statement],
3742        context: &StatementContext,
3743    ) -> BackendResult {
3744        // Add to the set in order to track the stack size.
3745        #[cfg(test)]
3746        self.put_block_stack_pointers
3747            .insert(ptr::from_ref(&level).cast());
3748
3749        for statement in statements {
3750            log::trace!("statement[{}] {:?}", level.0, statement);
3751            match *statement {
3752                crate::Statement::Emit(ref range) => {
3753                    for handle in range.clone() {
3754                        use crate::MathFunction as Mf;
3755
3756                        match context.expression.function.expressions[handle] {
3757                            // `ImageLoad` expressions covered by the `Restrict` bounds check policy
3758                            // may need to cache a clamped version of their level-of-detail argument.
3759                            crate::Expression::ImageLoad {
3760                                image,
3761                                level: mip_level,
3762                                ..
3763                            } => {
3764                                self.put_cache_restricted_level(
3765                                    handle, image, mip_level, level, context,
3766                                )?;
3767                            }
3768
3769                            // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3770                            // 2.1+ then we introduce two intermediate variables that recast the two
3771                            // arguments as packed (signed or unsigned) chars. The actual dot product
3772                            // is implemented in `Self::put_expression`, and it uses both of these
3773                            // intermediate variables multiple times. There's no danger that the
3774                            // original arguments get modified between the definition of these
3775                            // intermediate variables and the implementation of the actual dot
3776                            // product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3777                            crate::Expression::Math {
3778                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3779                                arg,
3780                                arg1,
3781                                ..
3782                            } if context.expression.lang_version >= (2, 1) => {
3783                                self.put_casting_to_packed_chars(
3784                                    fun,
3785                                    arg,
3786                                    arg1.unwrap(),
3787                                    level,
3788                                    context,
3789                                )?;
3790                            }
3791
3792                            _ => (),
3793                        }
3794
3795                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
3796                        let expr_name = if ptr_class.is_some() {
3797                            None // don't bake pointer expressions (just yet)
3798                        } else if let Some(name) =
3799                            context.expression.function.named_expressions.get(&handle)
3800                        {
3801                            // The `crate::Function::named_expressions` table holds
3802                            // expressions that should be saved in temporaries once they
3803                            // are `Emit`ted. We only add them to `self.named_expressions`
3804                            // when we reach the `Emit` that covers them, so that we don't
3805                            // try to use their names before we've actually initialized
3806                            // the temporary that holds them.
3807                            //
3808                            // Don't assume the names in `named_expressions` are unique,
3809                            // or even valid. Use the `Namer`.
3810                            Some(self.namer.call(name))
3811                        } else {
3812                            // If this expression is an index that we're going to first compare
3813                            // against a limit, and then actually use as an index, then we may
3814                            // want to cache it in a temporary, to avoid evaluating it twice.
3815                            let bake = if context.expression.guarded_indices.contains(handle) {
3816                                true
3817                            } else {
3818                                self.need_bake_expressions.contains(&handle)
3819                            };
3820
3821                            if bake {
3822                                Some(Baked(handle).to_string())
3823                            } else {
3824                                None
3825                            }
3826                        };
3827
3828                        if let Some(name) = expr_name {
3829                            write!(self.out, "{level}")?;
3830                            self.start_baking_expression(handle, &context.expression, &name)?;
3831                            self.put_expression(handle, &context.expression, true)?;
3832                            self.named_expressions.insert(handle, name);
3833                            writeln!(self.out, ";")?;
3834                        }
3835                    }
3836                }
3837                crate::Statement::Block(ref block) => {
3838                    if !block.is_empty() {
3839                        writeln!(self.out, "{level}{{")?;
3840                        self.put_block(level.next(), block, context)?;
3841                        writeln!(self.out, "{level}}}")?;
3842                    }
3843                }
3844                crate::Statement::If {
3845                    condition,
3846                    ref accept,
3847                    ref reject,
3848                } => {
3849                    write!(self.out, "{level}if (")?;
3850                    self.put_expression(condition, &context.expression, true)?;
3851                    writeln!(self.out, ") {{")?;
3852                    self.put_block(level.next(), accept, context)?;
3853                    if !reject.is_empty() {
3854                        writeln!(self.out, "{level}}} else {{")?;
3855                        self.put_block(level.next(), reject, context)?;
3856                    }
3857                    writeln!(self.out, "{level}}}")?;
3858                }
3859                crate::Statement::Switch {
3860                    selector,
3861                    ref cases,
3862                } => {
3863                    write!(self.out, "{level}switch(")?;
3864                    self.put_expression(selector, &context.expression, true)?;
3865                    writeln!(self.out, ") {{")?;
3866                    let lcase = level.next();
3867                    for case in cases.iter() {
3868                        match case.value {
3869                            crate::SwitchValue::I32(value) => {
3870                                write!(self.out, "{lcase}case {value}:")?;
3871                            }
3872                            crate::SwitchValue::U32(value) => {
3873                                write!(self.out, "{lcase}case {value}u:")?;
3874                            }
3875                            crate::SwitchValue::Default => {
3876                                write!(self.out, "{lcase}default:")?;
3877                            }
3878                        }
3879
3880                        let write_block_braces = !(case.fall_through && case.body.is_empty());
3881                        if write_block_braces {
3882                            writeln!(self.out, " {{")?;
3883                        } else {
3884                            writeln!(self.out)?;
3885                        }
3886
3887                        self.put_block(lcase.next(), &case.body, context)?;
3888                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3889                        {
3890                            writeln!(self.out, "{}break;", lcase.next())?;
3891                        }
3892
3893                        if write_block_braces {
3894                            writeln!(self.out, "{lcase}}}")?;
3895                        }
3896                    }
3897                    writeln!(self.out, "{level}}}")?;
3898                }
3899                crate::Statement::Loop {
3900                    ref body,
3901                    ref continuing,
3902                    break_if,
3903                } => {
3904                    let force_loop_bound_statements =
3905                        self.gen_force_bounded_loop_statements(level, context);
3906                    let gate_name = (!continuing.is_empty() || break_if.is_some())
3907                        .then(|| self.namer.call("loop_init"));
3908
3909                    if let Some((ref decl, _)) = force_loop_bound_statements {
3910                        writeln!(self.out, "{decl}")?;
3911                    }
3912                    if let Some(ref gate_name) = gate_name {
3913                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
3914                    }
3915
3916                    writeln!(self.out, "{level}while(true) {{",)?;
3917                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3918                        writeln!(self.out, "{break_and_inc}")?;
3919                    }
3920                    if let Some(ref gate_name) = gate_name {
3921                        let lif = level.next();
3922                        let lcontinuing = lif.next();
3923                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3924                        self.put_block(lcontinuing, continuing, context)?;
3925                        if let Some(condition) = break_if {
3926                            write!(self.out, "{lcontinuing}if (")?;
3927                            self.put_expression(condition, &context.expression, true)?;
3928                            writeln!(self.out, ") {{")?;
3929                            writeln!(self.out, "{}break;", lcontinuing.next())?;
3930                            writeln!(self.out, "{lcontinuing}}}")?;
3931                        }
3932                        writeln!(self.out, "{lif}}}")?;
3933                        writeln!(self.out, "{lif}{gate_name} = false;")?;
3934                    }
3935                    self.put_block(level.next(), body, context)?;
3936
3937                    writeln!(self.out, "{level}}}")?;
3938                }
3939                crate::Statement::Break => {
3940                    writeln!(self.out, "{level}break;")?;
3941                }
3942                crate::Statement::Continue => {
3943                    writeln!(self.out, "{level}continue;")?;
3944                }
3945                crate::Statement::Return {
3946                    value: Some(expr_handle),
3947                } => {
3948                    self.put_return_value(
3949                        level,
3950                        expr_handle,
3951                        context.result_struct,
3952                        &context.expression,
3953                    )?;
3954                }
3955                crate::Statement::Return { value: None } => {
3956                    writeln!(self.out, "{level}return;")?;
3957                }
3958                crate::Statement::Kill => {
3959                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3960                }
3961                crate::Statement::ControlBarrier(flags)
3962                | crate::Statement::MemoryBarrier(flags) => {
3963                    self.write_barrier(flags, level)?;
3964                }
3965                crate::Statement::Store { pointer, value } => {
3966                    self.put_store(pointer, value, level, context)?
3967                }
3968                crate::Statement::ImageStore {
3969                    image,
3970                    coordinate,
3971                    array_index,
3972                    value,
3973                } => {
3974                    let address = TexelAddress {
3975                        coordinate,
3976                        array_index,
3977                        sample: None,
3978                        level: None,
3979                    };
3980                    self.put_image_store(level, image, &address, value, context)?
3981                }
3982                crate::Statement::Call {
3983                    function,
3984                    ref arguments,
3985                    result,
3986                } => {
3987                    write!(self.out, "{level}")?;
3988                    if let Some(expr) = result {
3989                        let name = Baked(expr).to_string();
3990                        self.start_baking_expression(expr, &context.expression, &name)?;
3991                        self.named_expressions.insert(expr, name);
3992                    }
3993                    let fun_name = &self.names[&NameKey::Function(function)];
3994                    write!(self.out, "{fun_name}(")?;
3995                    // first, write down the actual arguments
3996                    for (i, &handle) in arguments.iter().enumerate() {
3997                        if i != 0 {
3998                            write!(self.out, ", ")?;
3999                        }
4000                        self.put_expression(handle, &context.expression, true)?;
4001                    }
4002                    // follow-up with any global resources used
4003                    let mut separate = !arguments.is_empty();
4004                    let fun_info = &context.expression.mod_info[function];
4005                    let mut needs_buffer_sizes = false;
4006                    for (handle, var) in context.expression.module.global_variables.iter() {
4007                        if fun_info[handle].is_empty() {
4008                            continue;
4009                        }
4010                        if var.space.needs_pass_through() {
4011                            let name = &self.names[&NameKey::GlobalVariable(handle)];
4012                            if separate {
4013                                write!(self.out, ", ")?;
4014                            } else {
4015                                separate = true;
4016                            }
4017                            write!(self.out, "{name}")?;
4018                        }
4019                        needs_buffer_sizes |=
4020                            needs_array_length(var.ty, &context.expression.module.types);
4021                    }
4022                    if needs_buffer_sizes {
4023                        if separate {
4024                            write!(self.out, ", ")?;
4025                        }
4026                        write!(self.out, "_buffer_sizes")?;
4027                    }
4028
4029                    // done
4030                    writeln!(self.out, ");")?;
4031                }
4032                crate::Statement::Atomic {
4033                    pointer,
4034                    ref fun,
4035                    value,
4036                    result,
4037                } => {
4038                    let context = &context.expression;
4039
4040                    // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
4041                    // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
4042                    // `Some`, we are not operating on a 64-bit value, and that if we are
4043                    // operating on a 64-bit value, `result` is `None`.
4044                    write!(self.out, "{level}")?;
4045                    let fun_key = if let Some(result) = result {
4046                        let res_name = Baked(result).to_string();
4047                        self.start_baking_expression(result, context, &res_name)?;
4048                        self.named_expressions.insert(result, res_name);
4049                        fun.to_msl()
4050                    } else if context.resolve_type(value).scalar_width() == Some(8) {
4051                        fun.to_msl_64_bit()?
4052                    } else {
4053                        fun.to_msl()
4054                    };
4055
4056                    // If the pointer we're passing to the atomic operation needs to be conditional
4057                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
4058                    // the pointer operand should be unchecked.
4059                    let policy = context.choose_bounds_check_policy(pointer);
4060                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4061                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4062
4063                    // If requested and successfully put bounds checks, continue the ternary expression.
4064                    if checked {
4065                        write!(self.out, " ? ")?;
4066                    }
4067
4068                    // Put the atomic function invocation.
4069                    match *fun {
4070                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4071                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4072                            self.put_access_chain(pointer, policy, context)?;
4073                            write!(self.out, ", ")?;
4074                            self.put_expression(cmp, context, true)?;
4075                            write!(self.out, ", ")?;
4076                            self.put_expression(value, context, true)?;
4077                            write!(self.out, ")")?;
4078                        }
4079                        _ => {
4080                            write!(
4081                                self.out,
4082                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4083                            )?;
4084                            self.put_access_chain(pointer, policy, context)?;
4085                            write!(self.out, ", ")?;
4086                            self.put_expression(value, context, true)?;
4087                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4088                        }
4089                    }
4090
4091                    // Finish the ternary expression.
4092                    if checked {
4093                        write!(self.out, " : DefaultConstructible()")?;
4094                    }
4095
4096                    // Done
4097                    writeln!(self.out, ";")?;
4098                }
4099                crate::Statement::ImageAtomic {
4100                    image,
4101                    coordinate,
4102                    array_index,
4103                    fun,
4104                    value,
4105                } => {
4106                    let address = TexelAddress {
4107                        coordinate,
4108                        array_index,
4109                        sample: None,
4110                        level: None,
4111                    };
4112                    self.put_image_atomic(level, image, &address, fun, value, context)?
4113                }
4114                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4115                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4116
4117                    write!(self.out, "{level}")?;
4118                    let name = self.namer.call("");
4119                    self.start_baking_expression(result, &context.expression, &name)?;
4120                    self.put_load(pointer, &context.expression, true)?;
4121                    self.named_expressions.insert(result, name);
4122
4123                    writeln!(self.out, ";")?;
4124                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4125                }
4126                crate::Statement::RayQuery { query, ref fun } => {
4127                    self.write_ray_query_stmt(level, context, query, fun)?;
4128                }
4129                crate::Statement::SubgroupBallot { result, predicate } => {
4130                    write!(self.out, "{level}")?;
4131                    let name = self.namer.call("");
4132                    self.start_baking_expression(result, &context.expression, &name)?;
4133                    self.named_expressions.insert(result, name);
4134                    write!(
4135                        self.out,
4136                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4137                    )?;
4138                    if let Some(predicate) = predicate {
4139                        self.put_expression(predicate, &context.expression, true)?;
4140                    } else {
4141                        write!(self.out, "true")?;
4142                    }
4143                    writeln!(self.out, "), 0, 0, 0);")?;
4144                }
4145                crate::Statement::SubgroupCollectiveOperation {
4146                    op,
4147                    collective_op,
4148                    argument,
4149                    result,
4150                } => {
4151                    write!(self.out, "{level}")?;
4152                    let name = self.namer.call("");
4153                    self.start_baking_expression(result, &context.expression, &name)?;
4154                    self.named_expressions.insert(result, name);
4155                    match (collective_op, op) {
4156                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4157                            write!(self.out, "{NAMESPACE}::simd_all(")?
4158                        }
4159                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4160                            write!(self.out, "{NAMESPACE}::simd_any(")?
4161                        }
4162                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4163                            write!(self.out, "{NAMESPACE}::simd_sum(")?
4164                        }
4165                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4166                            write!(self.out, "{NAMESPACE}::simd_product(")?
4167                        }
4168                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4169                            write!(self.out, "{NAMESPACE}::simd_max(")?
4170                        }
4171                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4172                            write!(self.out, "{NAMESPACE}::simd_min(")?
4173                        }
4174                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4175                            write!(self.out, "{NAMESPACE}::simd_and(")?
4176                        }
4177                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4178                            write!(self.out, "{NAMESPACE}::simd_or(")?
4179                        }
4180                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4181                            write!(self.out, "{NAMESPACE}::simd_xor(")?
4182                        }
4183                        (
4184                            crate::CollectiveOperation::ExclusiveScan,
4185                            crate::SubgroupOperation::Add,
4186                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4187                        (
4188                            crate::CollectiveOperation::ExclusiveScan,
4189                            crate::SubgroupOperation::Mul,
4190                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4191                        (
4192                            crate::CollectiveOperation::InclusiveScan,
4193                            crate::SubgroupOperation::Add,
4194                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4195                        (
4196                            crate::CollectiveOperation::InclusiveScan,
4197                            crate::SubgroupOperation::Mul,
4198                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4199                        _ => unimplemented!(),
4200                    }
4201                    self.put_expression(argument, &context.expression, true)?;
4202                    writeln!(self.out, ");")?;
4203                }
4204                crate::Statement::SubgroupGather {
4205                    mode,
4206                    argument,
4207                    result,
4208                } => {
4209                    write!(self.out, "{level}")?;
4210                    let name = self.namer.call("");
4211                    self.start_baking_expression(result, &context.expression, &name)?;
4212                    self.named_expressions.insert(result, name);
4213                    match mode {
4214                        crate::GatherMode::BroadcastFirst => {
4215                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4216                        }
4217                        crate::GatherMode::Broadcast(_) => {
4218                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4219                        }
4220                        crate::GatherMode::Shuffle(_) => {
4221                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4222                        }
4223                        crate::GatherMode::ShuffleDown(_) => {
4224                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4225                        }
4226                        crate::GatherMode::ShuffleUp(_) => {
4227                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4228                        }
4229                        crate::GatherMode::ShuffleXor(_) => {
4230                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4231                        }
4232                        crate::GatherMode::QuadBroadcast(_) => {
4233                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4234                        }
4235                        crate::GatherMode::QuadSwap(_) => {
4236                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4237                        }
4238                    }
4239                    self.put_expression(argument, &context.expression, true)?;
4240                    match mode {
4241                        crate::GatherMode::BroadcastFirst => {}
4242                        crate::GatherMode::Broadcast(index)
4243                        | crate::GatherMode::Shuffle(index)
4244                        | crate::GatherMode::ShuffleDown(index)
4245                        | crate::GatherMode::ShuffleUp(index)
4246                        | crate::GatherMode::ShuffleXor(index)
4247                        | crate::GatherMode::QuadBroadcast(index) => {
4248                            write!(self.out, ", ")?;
4249                            self.put_expression(index, &context.expression, true)?;
4250                        }
4251                        crate::GatherMode::QuadSwap(direction) => {
4252                            write!(self.out, ", ")?;
4253                            match direction {
4254                                crate::Direction::X => {
4255                                    write!(self.out, "1u")?;
4256                                }
4257                                crate::Direction::Y => {
4258                                    write!(self.out, "2u")?;
4259                                }
4260                                crate::Direction::Diagonal => {
4261                                    write!(self.out, "3u")?;
4262                                }
4263                            }
4264                        }
4265                    }
4266                    writeln!(self.out, ");")?;
4267                }
4268                crate::Statement::CooperativeStore { target, ref data } => {
4269                    write!(self.out, "{level}simdgroup_store(")?;
4270                    self.put_expression(target, &context.expression, true)?;
4271                    write!(self.out, ", &")?;
4272                    self.put_access_chain(
4273                        data.pointer,
4274                        context.expression.policies.index,
4275                        &context.expression,
4276                    )?;
4277                    write!(self.out, ", ")?;
4278                    self.put_expression(data.stride, &context.expression, true)?;
4279                    // See the comment in `CooperativeLoad` above: WGSL's
4280                    // row_major flag is negated when emitting Metal's
4281                    // `transpose` flag, so a col-major store (row_major=false)
4282                    // must use `transpose=true`.
4283                    if !data.row_major {
4284                        let matrix_origin = "0";
4285                        let transpose = true;
4286                        write!(self.out, ", {matrix_origin}, {transpose}")?;
4287                    }
4288                    writeln!(self.out, ");")?;
4289                }
4290                crate::Statement::RayPipelineFunction(_) => unreachable!(),
4291            }
4292        }
4293
4294        // un-emit expressions
4295        //TODO: take care of loop/continuing?
4296        for statement in statements {
4297            if let crate::Statement::Emit(ref range) = *statement {
4298                for handle in range.clone() {
4299                    self.named_expressions.shift_remove(&handle);
4300                }
4301            }
4302        }
4303        Ok(())
4304    }
4305
4306    fn put_store(
4307        &mut self,
4308        pointer: Handle<crate::Expression>,
4309        value: Handle<crate::Expression>,
4310        level: back::Level,
4311        context: &StatementContext,
4312    ) -> BackendResult {
4313        let policy = context.expression.choose_bounds_check_policy(pointer);
4314        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4315            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4316        {
4317            writeln!(self.out, ") {{")?;
4318            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4319            writeln!(self.out, "{level}}}")?;
4320        } else {
4321            self.put_unchecked_store(pointer, value, policy, level, context)?;
4322        }
4323
4324        Ok(())
4325    }
4326
4327    fn put_unchecked_store(
4328        &mut self,
4329        pointer: Handle<crate::Expression>,
4330        value: Handle<crate::Expression>,
4331        policy: index::BoundsCheckPolicy,
4332        level: back::Level,
4333        context: &StatementContext,
4334    ) -> BackendResult {
4335        let is_atomic_pointer = context
4336            .expression
4337            .resolve_type(pointer)
4338            .is_atomic_pointer(&context.expression.module.types);
4339
4340        if is_atomic_pointer {
4341            write!(
4342                self.out,
4343                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4344            )?;
4345            self.put_access_chain(pointer, policy, &context.expression)?;
4346            write!(self.out, ", ")?;
4347            self.put_expression(value, &context.expression, true)?;
4348            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4349        } else {
4350            write!(self.out, "{level}")?;
4351            self.put_access_chain(pointer, policy, &context.expression)?;
4352            write!(self.out, " = ")?;
4353            self.put_expression(value, &context.expression, true)?;
4354            writeln!(self.out, ";")?;
4355        }
4356
4357        Ok(())
4358    }
4359
4360    pub fn write(
4361        &mut self,
4362        module: &crate::Module,
4363        info: &valid::ModuleInfo,
4364        options: &Options,
4365        pipeline_options: &PipelineOptions,
4366    ) -> Result<TranslationInfo, Error> {
4367        self.names.clear();
4368        self.namer.reset(
4369            module,
4370            &super::keywords::RESERVED_SET,
4371            proc::KeywordSet::empty(),
4372            proc::CaseInsensitiveKeywordSet::empty(),
4373            &[
4374                CLAMPED_LOD_LOAD_PREFIX,
4375                super::ray::INTERSECTION_FUNCTION_NAME,
4376            ],
4377            &mut self.names,
4378        );
4379        self.wrapped_functions.clear();
4380        self.struct_member_pads.clear();
4381
4382        writeln!(
4383            self.out,
4384            "// language: metal{}.{}",
4385            options.lang_version.0, options.lang_version.1
4386        )?;
4387        writeln!(self.out, "#include <metal_stdlib>")?;
4388        writeln!(self.out, "#include <simd/simd.h>")?;
4389        writeln!(self.out)?;
4390        // Work around Metal bug where `uint` is not available by default
4391        writeln!(self.out, "using {NAMESPACE}::uint;")?;
4392
4393        if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4394            return Err(Error::UnsupportedMeshShader);
4395        }
4396        self.needs_object_memory_barriers = module
4397            .entry_points
4398            .iter()
4399            .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4400
4401        if module.special_types.ray_desc.is_some()
4402            || module.special_types.ray_intersection.is_some()
4403        {
4404            if options.lang_version < (2, 4) {
4405                return Err(Error::UnsupportedRayTracing);
4406            }
4407        }
4408
4409        if options
4410            .bounds_check_policies
4411            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4412        {
4413            self.put_default_constructible()?;
4414        }
4415        writeln!(self.out)?;
4416
4417        {
4418            // Make a `Vec` of all the `GlobalVariable`s that contain
4419            // runtime-sized arrays.
4420            let globals: Vec<Handle<crate::GlobalVariable>> = module
4421                .global_variables
4422                .iter()
4423                .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4424                .map(|(handle, _)| handle)
4425                .collect();
4426
4427            let mut buffer_indices = vec![];
4428            for vbm in &pipeline_options.vertex_buffer_mappings {
4429                buffer_indices.push(vbm.id);
4430            }
4431
4432            if !globals.is_empty() || !buffer_indices.is_empty() {
4433                writeln!(self.out, "struct _mslBufferSizes {{")?;
4434
4435                for global in globals {
4436                    writeln!(
4437                        self.out,
4438                        "{}uint {};",
4439                        back::INDENT,
4440                        ArraySizeMember(global)
4441                    )?;
4442                }
4443
4444                for idx in buffer_indices {
4445                    writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4446                }
4447
4448                writeln!(self.out, "}};")?;
4449                writeln!(self.out)?;
4450            }
4451        };
4452
4453        self.write_type_defs(module)?;
4454        self.write_global_constants(module, info)?;
4455        self.write_functions(module, info, options, pipeline_options)
4456    }
4457
4458    /// Write the definition for the `DefaultConstructible` class.
4459    ///
4460    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
4461    /// produce 'zero' values for any type, including structs, arrays, and so
4462    /// on. We could do this by emitting default constructor applications, but
4463    /// that would entail printing the name of the type, which is more trouble
4464    /// than you'd think. Instead, we just construct this magic C++14 class that
4465    /// can be converted to any type that can be default constructed, using
4466    /// template parameter inference to detect which type is needed, so we don't
4467    /// have to figure out the name.
4468    ///
4469    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
4470    fn put_default_constructible(&mut self) -> BackendResult {
4471        let tab = back::INDENT;
4472        writeln!(self.out, "struct DefaultConstructible {{")?;
4473        writeln!(self.out, "{tab}template<typename T>")?;
4474        writeln!(self.out, "{tab}operator T() && {{")?;
4475        writeln!(self.out, "{tab}{tab}return T {{}};")?;
4476        writeln!(self.out, "{tab}}}")?;
4477        writeln!(self.out, "}};")?;
4478        Ok(())
4479    }
4480
4481    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4482        let mut generated_argument_buffer_wrapper = false;
4483        let mut generated_external_texture_wrapper = false;
4484        for (handle, ty) in module.types.iter() {
4485            match ty.inner {
4486                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4487                    writeln!(self.out, "template <typename T>")?;
4488                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4489                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4490                    writeln!(self.out, "}};")?;
4491                    generated_argument_buffer_wrapper = true;
4492                }
4493                crate::TypeInner::Image {
4494                    class: crate::ImageClass::External,
4495                    ..
4496                } if !generated_external_texture_wrapper => {
4497                    let params_ty_name = &self.names
4498                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4499                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4500                    writeln!(
4501                        self.out,
4502                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4503                        back::INDENT
4504                    )?;
4505                    writeln!(
4506                        self.out,
4507                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4508                        back::INDENT
4509                    )?;
4510                    writeln!(
4511                        self.out,
4512                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4513                        back::INDENT
4514                    )?;
4515                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4516                    writeln!(self.out, "}};")?;
4517                    generated_external_texture_wrapper = true;
4518                }
4519                _ => {}
4520            }
4521
4522            if !ty.needs_alias() {
4523                continue;
4524            }
4525            let name = &self.names[&NameKey::Type(handle)];
4526            match ty.inner {
4527                // Naga IR can pass around arrays by value, but Metal, following
4528                // C++, performs an array-to-pointer conversion (C++ [conv.array])
4529                // on expressions of array type, so assigning the array by value
4530                // isn't possible. However, Metal *does* assign structs by
4531                // value. So in our Metal output, we wrap all array types in
4532                // synthetic struct types:
4533                //
4534                //     struct type1 {
4535                //         float inner[10]
4536                //     };
4537                //
4538                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
4539                // any expression that actually wants access to the array.
4540                crate::TypeInner::Array {
4541                    base,
4542                    size,
4543                    stride: _,
4544                } => {
4545                    let base_name = TypeContext {
4546                        handle: base,
4547                        gctx: module.to_ctx(),
4548                        names: &self.names,
4549                        access: crate::StorageAccess::empty(),
4550                        first_time: false,
4551                    };
4552
4553                    match size.resolve(module.to_ctx())? {
4554                        proc::IndexableLength::Known(size) => {
4555                            writeln!(self.out, "struct {name} {{")?;
4556                            writeln!(
4557                                self.out,
4558                                "{}{} {}[{}];",
4559                                back::INDENT,
4560                                base_name,
4561                                WRAPPED_ARRAY_FIELD,
4562                                size
4563                            )?;
4564                            writeln!(self.out, "}};")?;
4565                        }
4566                        proc::IndexableLength::Dynamic => {
4567                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
4568                        }
4569                    }
4570                }
4571                crate::TypeInner::Struct {
4572                    ref members, span, ..
4573                } => {
4574                    writeln!(self.out, "struct {name} {{")?;
4575                    let mut last_offset = 0;
4576                    for (index, member) in members.iter().enumerate() {
4577                        if member.offset > last_offset {
4578                            self.struct_member_pads.insert((handle, index as u32));
4579                            let pad = member.offset - last_offset;
4580                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4581                        }
4582                        let ty_inner = &module.types[member.ty].inner;
4583                        last_offset = member.offset + ty_inner.size(module.to_ctx());
4584
4585                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4586
4587                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
4588                        match should_pack_struct_member(members, span, index, module) {
4589                            Some(scalar) => {
4590                                writeln!(
4591                                    self.out,
4592                                    "{}{}::packed_{}3 {};",
4593                                    back::INDENT,
4594                                    NAMESPACE,
4595                                    scalar.to_msl_name(),
4596                                    member_name
4597                                )?;
4598                            }
4599                            None => {
4600                                let base_name = TypeContext {
4601                                    handle: member.ty,
4602                                    gctx: module.to_ctx(),
4603                                    names: &self.names,
4604                                    access: crate::StorageAccess::empty(),
4605                                    first_time: false,
4606                                };
4607                                writeln!(
4608                                    self.out,
4609                                    "{}{} {};",
4610                                    back::INDENT,
4611                                    base_name,
4612                                    member_name
4613                                )?;
4614
4615                                // for 3-component vectors, add one component
4616                                if let crate::TypeInner::Vector {
4617                                    size: crate::VectorSize::Tri,
4618                                    scalar,
4619                                } = *ty_inner
4620                                {
4621                                    last_offset += scalar.width as u32;
4622                                }
4623                            }
4624                        }
4625                    }
4626                    if last_offset < span {
4627                        let pad = span - last_offset;
4628                        writeln!(
4629                            self.out,
4630                            "{}char _pad{}[{}];",
4631                            back::INDENT,
4632                            members.len(),
4633                            pad
4634                        )?;
4635                    }
4636                    writeln!(self.out, "}};")?;
4637                }
4638                _ => {
4639                    let ty_name = TypeContext {
4640                        handle,
4641                        gctx: module.to_ctx(),
4642                        names: &self.names,
4643                        access: crate::StorageAccess::empty(),
4644                        first_time: true,
4645                    };
4646                    writeln!(self.out, "typedef {ty_name} {name};")?;
4647                }
4648            }
4649        }
4650
4651        // Write functions to create special types.
4652        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4653            match type_key {
4654                &crate::PredeclaredType::ModfResult { size, scalar }
4655                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4656                    let arg_type_name_owner;
4657                    let arg_type_name = if let Some(size) = size {
4658                        arg_type_name_owner = format!(
4659                            "{NAMESPACE}::{}{}",
4660                            if scalar.width == 8 { "double" } else { "float" },
4661                            size as u8
4662                        );
4663                        &arg_type_name_owner
4664                    } else if scalar.width == 8 {
4665                        "double"
4666                    } else {
4667                        "float"
4668                    };
4669
4670                    let other_type_name_owner;
4671                    let (defined_func_name, called_func_name, other_type_name) =
4672                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4673                            (MODF_FUNCTION, "modf", arg_type_name)
4674                        } else {
4675                            let other_type_name = if let Some(size) = size {
4676                                other_type_name_owner = format!("int{}", size as u8);
4677                                &other_type_name_owner
4678                            } else {
4679                                "int"
4680                            };
4681                            (FREXP_FUNCTION, "frexp", other_type_name)
4682                        };
4683
4684                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4685
4686                    writeln!(self.out)?;
4687                    writeln!(
4688                        self.out,
4689                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4690    {other_type_name} other;
4691    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4692    return {struct_name}{{ fract, other }};
4693}}"
4694                    )?;
4695                }
4696                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4697                    let arg_type_name = scalar.to_msl_name();
4698                    let called_func_name = "atomic_compare_exchange_weak_explicit";
4699                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4700                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4701
4702                    writeln!(self.out)?;
4703
4704                    for address_space_name in ["device", "threadgroup"] {
4705                        writeln!(
4706                            self.out,
4707                            "\
4708template <typename A>
4709{struct_name} {defined_func_name}(
4710    {address_space_name} A *atomic_ptr,
4711    {arg_type_name} cmp,
4712    {arg_type_name} v
4713) {{
4714    bool swapped = {NAMESPACE}::{called_func_name}(
4715        atomic_ptr, &cmp, v,
4716        metal::memory_order_relaxed, metal::memory_order_relaxed
4717    );
4718    return {struct_name}{{cmp, swapped}};
4719}}"
4720                        )?;
4721                    }
4722                }
4723            }
4724        }
4725
4726        Ok(())
4727    }
4728
4729    /// Writes all named constants
4730    fn write_global_constants(
4731        &mut self,
4732        module: &crate::Module,
4733        mod_info: &valid::ModuleInfo,
4734    ) -> BackendResult {
4735        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4736
4737        for (handle, constant) in constants {
4738            let ty_name = TypeContext {
4739                handle: constant.ty,
4740                gctx: module.to_ctx(),
4741                names: &self.names,
4742                access: crate::StorageAccess::empty(),
4743                first_time: false,
4744            };
4745            let name = &self.names[&NameKey::Constant(handle)];
4746            write!(self.out, "constant {ty_name} {name} = ")?;
4747            self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4748            writeln!(self.out, ";")?;
4749        }
4750
4751        Ok(())
4752    }
4753
4754    fn put_inline_sampler_properties(
4755        &mut self,
4756        level: back::Level,
4757        sampler: &sm::InlineSampler,
4758    ) -> BackendResult {
4759        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4760            writeln!(
4761                self.out,
4762                "{}{}::{}_address::{},",
4763                level,
4764                NAMESPACE,
4765                letter,
4766                address.as_str(),
4767            )?;
4768        }
4769        writeln!(
4770            self.out,
4771            "{}{}::mag_filter::{},",
4772            level,
4773            NAMESPACE,
4774            sampler.mag_filter.as_str(),
4775        )?;
4776        writeln!(
4777            self.out,
4778            "{}{}::min_filter::{},",
4779            level,
4780            NAMESPACE,
4781            sampler.min_filter.as_str(),
4782        )?;
4783        if let Some(filter) = sampler.mip_filter {
4784            writeln!(
4785                self.out,
4786                "{}{}::mip_filter::{},",
4787                level,
4788                NAMESPACE,
4789                filter.as_str(),
4790            )?;
4791        }
4792        // avoid setting it on platforms that don't support it
4793        if sampler.border_color != sm::BorderColor::TransparentBlack {
4794            writeln!(
4795                self.out,
4796                "{}{}::border_color::{},",
4797                level,
4798                NAMESPACE,
4799                sampler.border_color.as_str(),
4800            )?;
4801        }
4802        //TODO: I'm not able to feed this in a way that MSL likes:
4803        //>error: use of undeclared identifier 'lod_clamp'
4804        //>error: no member named 'max_anisotropy' in namespace 'metal'
4805        if false {
4806            if let Some(ref lod) = sampler.lod_clamp {
4807                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4808            }
4809            if let Some(aniso) = sampler.max_anisotropy {
4810                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4811            }
4812        }
4813        if sampler.compare_func != sm::CompareFunc::Never {
4814            writeln!(
4815                self.out,
4816                "{}{}::compare_func::{},",
4817                level,
4818                NAMESPACE,
4819                sampler.compare_func.as_str(),
4820            )?;
4821        }
4822        writeln!(
4823            self.out,
4824            "{}{}::coord::{}",
4825            level,
4826            NAMESPACE,
4827            sampler.coord.as_str()
4828        )?;
4829        Ok(())
4830    }
4831
4832    fn write_unpacking_function(
4833        &mut self,
4834        format: back::msl::VertexFormat,
4835    ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4836        use crate::{Scalar, VectorSize};
4837        use back::msl::VertexFormat::*;
4838        match format {
4839            Uint8 => {
4840                let name = self.namer.call("unpackUint8");
4841                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4842                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4843                writeln!(self.out, "}}")?;
4844                Ok((name, 1, None, Scalar::U32))
4845            }
4846            Uint8x2 => {
4847                let name = self.namer.call("unpackUint8x2");
4848                writeln!(
4849                    self.out,
4850                    "metal::uint2 {name}(metal::uchar b0, \
4851                                         metal::uchar b1) {{"
4852                )?;
4853                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4854                writeln!(self.out, "}}")?;
4855                Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4856            }
4857            Uint8x4 => {
4858                let name = self.namer.call("unpackUint8x4");
4859                writeln!(
4860                    self.out,
4861                    "metal::uint4 {name}(metal::uchar b0, \
4862                                         metal::uchar b1, \
4863                                         metal::uchar b2, \
4864                                         metal::uchar b3) {{"
4865                )?;
4866                writeln!(
4867                    self.out,
4868                    "{}return metal::uint4(b0, b1, b2, b3);",
4869                    back::INDENT
4870                )?;
4871                writeln!(self.out, "}}")?;
4872                Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
4873            }
4874            Sint8 => {
4875                let name = self.namer.call("unpackSint8");
4876                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4877                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4878                writeln!(self.out, "}}")?;
4879                Ok((name, 1, None, Scalar::I32))
4880            }
4881            Sint8x2 => {
4882                let name = self.namer.call("unpackSint8x2");
4883                writeln!(
4884                    self.out,
4885                    "metal::int2 {name}(metal::uchar b0, \
4886                                        metal::uchar b1) {{"
4887                )?;
4888                writeln!(
4889                    self.out,
4890                    "{}return metal::int2(as_type<char>(b0), \
4891                                          as_type<char>(b1));",
4892                    back::INDENT
4893                )?;
4894                writeln!(self.out, "}}")?;
4895                Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
4896            }
4897            Sint8x4 => {
4898                let name = self.namer.call("unpackSint8x4");
4899                writeln!(
4900                    self.out,
4901                    "metal::int4 {name}(metal::uchar b0, \
4902                                        metal::uchar b1, \
4903                                        metal::uchar b2, \
4904                                        metal::uchar b3) {{"
4905                )?;
4906                writeln!(
4907                    self.out,
4908                    "{}return metal::int4(as_type<char>(b0), \
4909                                          as_type<char>(b1), \
4910                                          as_type<char>(b2), \
4911                                          as_type<char>(b3));",
4912                    back::INDENT
4913                )?;
4914                writeln!(self.out, "}}")?;
4915                Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
4916            }
4917            Unorm8 => {
4918                let name = self.namer.call("unpackUnorm8");
4919                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4920                writeln!(
4921                    self.out,
4922                    "{}return float(float(b0) / 255.0f);",
4923                    back::INDENT
4924                )?;
4925                writeln!(self.out, "}}")?;
4926                Ok((name, 1, None, Scalar::F32))
4927            }
4928            Unorm8x2 => {
4929                let name = self.namer.call("unpackUnorm8x2");
4930                writeln!(
4931                    self.out,
4932                    "metal::float2 {name}(metal::uchar b0, \
4933                                          metal::uchar b1) {{"
4934                )?;
4935                writeln!(
4936                    self.out,
4937                    "{}return metal::float2(float(b0) / 255.0f, \
4938                                            float(b1) / 255.0f);",
4939                    back::INDENT
4940                )?;
4941                writeln!(self.out, "}}")?;
4942                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4943            }
4944            Unorm8x4 => {
4945                let name = self.namer.call("unpackUnorm8x4");
4946                writeln!(
4947                    self.out,
4948                    "metal::float4 {name}(metal::uchar b0, \
4949                                          metal::uchar b1, \
4950                                          metal::uchar b2, \
4951                                          metal::uchar b3) {{"
4952                )?;
4953                writeln!(
4954                    self.out,
4955                    "{}return metal::float4(float(b0) / 255.0f, \
4956                                            float(b1) / 255.0f, \
4957                                            float(b2) / 255.0f, \
4958                                            float(b3) / 255.0f);",
4959                    back::INDENT
4960                )?;
4961                writeln!(self.out, "}}")?;
4962                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
4963            }
4964            Snorm8 => {
4965                let name = self.namer.call("unpackSnorm8");
4966                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4967                writeln!(
4968                    self.out,
4969                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4970                    back::INDENT
4971                )?;
4972                writeln!(self.out, "}}")?;
4973                Ok((name, 1, None, Scalar::F32))
4974            }
4975            Snorm8x2 => {
4976                let name = self.namer.call("unpackSnorm8x2");
4977                writeln!(
4978                    self.out,
4979                    "metal::float2 {name}(metal::uchar b0, \
4980                                          metal::uchar b1) {{"
4981                )?;
4982                writeln!(
4983                    self.out,
4984                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4985                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4986                    back::INDENT
4987                )?;
4988                writeln!(self.out, "}}")?;
4989                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4990            }
4991            Snorm8x4 => {
4992                let name = self.namer.call("unpackSnorm8x4");
4993                writeln!(
4994                    self.out,
4995                    "metal::float4 {name}(metal::uchar b0, \
4996                                          metal::uchar b1, \
4997                                          metal::uchar b2, \
4998                                          metal::uchar b3) {{"
4999                )?;
5000                writeln!(
5001                    self.out,
5002                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5003                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5004                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5005                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5006                    back::INDENT
5007                )?;
5008                writeln!(self.out, "}}")?;
5009                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5010            }
5011            Uint16 => {
5012                let name = self.namer.call("unpackUint16");
5013                writeln!(
5014                    self.out,
5015                    "metal::uint {name}(metal::uint b0, \
5016                                        metal::uint b1) {{"
5017                )?;
5018                writeln!(
5019                    self.out,
5020                    "{}return metal::uint(b1 << 8 | b0);",
5021                    back::INDENT
5022                )?;
5023                writeln!(self.out, "}}")?;
5024                Ok((name, 2, None, Scalar::U32))
5025            }
5026            Uint16x2 => {
5027                let name = self.namer.call("unpackUint16x2");
5028                writeln!(
5029                    self.out,
5030                    "metal::uint2 {name}(metal::uint b0, \
5031                                         metal::uint b1, \
5032                                         metal::uint b2, \
5033                                         metal::uint b3) {{"
5034                )?;
5035                writeln!(
5036                    self.out,
5037                    "{}return metal::uint2(b1 << 8 | b0, \
5038                                           b3 << 8 | b2);",
5039                    back::INDENT
5040                )?;
5041                writeln!(self.out, "}}")?;
5042                Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5043            }
5044            Uint16x4 => {
5045                let name = self.namer.call("unpackUint16x4");
5046                writeln!(
5047                    self.out,
5048                    "metal::uint4 {name}(metal::uint b0, \
5049                                         metal::uint b1, \
5050                                         metal::uint b2, \
5051                                         metal::uint b3, \
5052                                         metal::uint b4, \
5053                                         metal::uint b5, \
5054                                         metal::uint b6, \
5055                                         metal::uint b7) {{"
5056                )?;
5057                writeln!(
5058                    self.out,
5059                    "{}return metal::uint4(b1 << 8 | b0, \
5060                                           b3 << 8 | b2, \
5061                                           b5 << 8 | b4, \
5062                                           b7 << 8 | b6);",
5063                    back::INDENT
5064                )?;
5065                writeln!(self.out, "}}")?;
5066                Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5067            }
5068            Sint16 => {
5069                let name = self.namer.call("unpackSint16");
5070                writeln!(
5071                    self.out,
5072                    "int {name}(metal::ushort b0, \
5073                                metal::ushort b1) {{"
5074                )?;
5075                writeln!(
5076                    self.out,
5077                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5078                    back::INDENT
5079                )?;
5080                writeln!(self.out, "}}")?;
5081                Ok((name, 2, None, Scalar::I32))
5082            }
5083            Sint16x2 => {
5084                let name = self.namer.call("unpackSint16x2");
5085                writeln!(
5086                    self.out,
5087                    "metal::int2 {name}(metal::ushort b0, \
5088                                        metal::ushort b1, \
5089                                        metal::ushort b2, \
5090                                        metal::ushort b3) {{"
5091                )?;
5092                writeln!(
5093                    self.out,
5094                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5095                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
5096                    back::INDENT
5097                )?;
5098                writeln!(self.out, "}}")?;
5099                Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5100            }
5101            Sint16x4 => {
5102                let name = self.namer.call("unpackSint16x4");
5103                writeln!(
5104                    self.out,
5105                    "metal::int4 {name}(metal::ushort b0, \
5106                                        metal::ushort b1, \
5107                                        metal::ushort b2, \
5108                                        metal::ushort b3, \
5109                                        metal::ushort b4, \
5110                                        metal::ushort b5, \
5111                                        metal::ushort b6, \
5112                                        metal::ushort b7) {{"
5113                )?;
5114                writeln!(
5115                    self.out,
5116                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5117                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
5118                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
5119                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
5120                    back::INDENT
5121                )?;
5122                writeln!(self.out, "}}")?;
5123                Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5124            }
5125            Unorm16 => {
5126                let name = self.namer.call("unpackUnorm16");
5127                writeln!(
5128                    self.out,
5129                    "float {name}(metal::ushort b0, \
5130                                  metal::ushort b1) {{"
5131                )?;
5132                writeln!(
5133                    self.out,
5134                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5135                    back::INDENT
5136                )?;
5137                writeln!(self.out, "}}")?;
5138                Ok((name, 2, None, Scalar::F32))
5139            }
5140            Unorm16x2 => {
5141                let name = self.namer.call("unpackUnorm16x2");
5142                writeln!(
5143                    self.out,
5144                    "metal::float2 {name}(metal::ushort b0, \
5145                                          metal::ushort b1, \
5146                                          metal::ushort b2, \
5147                                          metal::ushort b3) {{"
5148                )?;
5149                writeln!(
5150                    self.out,
5151                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5152                                            float(b3 << 8 | b2) / 65535.0f);",
5153                    back::INDENT
5154                )?;
5155                writeln!(self.out, "}}")?;
5156                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5157            }
5158            Unorm16x4 => {
5159                let name = self.namer.call("unpackUnorm16x4");
5160                writeln!(
5161                    self.out,
5162                    "metal::float4 {name}(metal::ushort b0, \
5163                                          metal::ushort b1, \
5164                                          metal::ushort b2, \
5165                                          metal::ushort b3, \
5166                                          metal::ushort b4, \
5167                                          metal::ushort b5, \
5168                                          metal::ushort b6, \
5169                                          metal::ushort b7) {{"
5170                )?;
5171                writeln!(
5172                    self.out,
5173                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5174                                            float(b3 << 8 | b2) / 65535.0f, \
5175                                            float(b5 << 8 | b4) / 65535.0f, \
5176                                            float(b7 << 8 | b6) / 65535.0f);",
5177                    back::INDENT
5178                )?;
5179                writeln!(self.out, "}}")?;
5180                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5181            }
5182            Snorm16 => {
5183                let name = self.namer.call("unpackSnorm16");
5184                writeln!(
5185                    self.out,
5186                    "float {name}(metal::ushort b0, \
5187                                  metal::ushort b1) {{"
5188                )?;
5189                writeln!(
5190                    self.out,
5191                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5192                    back::INDENT
5193                )?;
5194                writeln!(self.out, "}}")?;
5195                Ok((name, 2, None, Scalar::F32))
5196            }
5197            Snorm16x2 => {
5198                let name = self.namer.call("unpackSnorm16x2");
5199                writeln!(
5200                    self.out,
5201                    "metal::float2 {name}(uint b0, \
5202                                          uint b1, \
5203                                          uint b2, \
5204                                          uint b3) {{"
5205                )?;
5206                writeln!(
5207                    self.out,
5208                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5209                    back::INDENT
5210                )?;
5211                writeln!(self.out, "}}")?;
5212                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5213            }
5214            Snorm16x4 => {
5215                let name = self.namer.call("unpackSnorm16x4");
5216                writeln!(
5217                    self.out,
5218                    "metal::float4 {name}(uint b0, \
5219                                          uint b1, \
5220                                          uint b2, \
5221                                          uint b3, \
5222                                          uint b4, \
5223                                          uint b5, \
5224                                          uint b6, \
5225                                          uint b7) {{"
5226                )?;
5227                writeln!(
5228                    self.out,
5229                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5230                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5231                    back::INDENT
5232                )?;
5233                writeln!(self.out, "}}")?;
5234                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5235            }
5236            Float16 => {
5237                let name = self.namer.call("unpackFloat16");
5238                writeln!(
5239                    self.out,
5240                    "float {name}(metal::ushort b0, \
5241                                  metal::ushort b1) {{"
5242                )?;
5243                writeln!(
5244                    self.out,
5245                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5246                    back::INDENT
5247                )?;
5248                writeln!(self.out, "}}")?;
5249                Ok((name, 2, None, Scalar::F32))
5250            }
5251            Float16x2 => {
5252                let name = self.namer.call("unpackFloat16x2");
5253                writeln!(
5254                    self.out,
5255                    "metal::float2 {name}(metal::ushort b0, \
5256                                          metal::ushort b1, \
5257                                          metal::ushort b2, \
5258                                          metal::ushort b3) {{"
5259                )?;
5260                writeln!(
5261                    self.out,
5262                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5263                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
5264                    back::INDENT
5265                )?;
5266                writeln!(self.out, "}}")?;
5267                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5268            }
5269            Float16x4 => {
5270                let name = self.namer.call("unpackFloat16x4");
5271                writeln!(
5272                    self.out,
5273                    "metal::float4 {name}(metal::ushort b0, \
5274                                        metal::ushort b1, \
5275                                        metal::ushort b2, \
5276                                        metal::ushort b3, \
5277                                        metal::ushort b4, \
5278                                        metal::ushort b5, \
5279                                        metal::ushort b6, \
5280                                        metal::ushort b7) {{"
5281                )?;
5282                writeln!(
5283                    self.out,
5284                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5285                                          as_type<half>(metal::ushort(b3 << 8 | b2)), \
5286                                          as_type<half>(metal::ushort(b5 << 8 | b4)), \
5287                                          as_type<half>(metal::ushort(b7 << 8 | b6)));",
5288                    back::INDENT
5289                )?;
5290                writeln!(self.out, "}}")?;
5291                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5292            }
5293            Float32 => {
5294                let name = self.namer.call("unpackFloat32");
5295                writeln!(
5296                    self.out,
5297                    "float {name}(uint b0, \
5298                                  uint b1, \
5299                                  uint b2, \
5300                                  uint b3) {{"
5301                )?;
5302                writeln!(
5303                    self.out,
5304                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5305                    back::INDENT
5306                )?;
5307                writeln!(self.out, "}}")?;
5308                Ok((name, 4, None, Scalar::F32))
5309            }
5310            Float32x2 => {
5311                let name = self.namer.call("unpackFloat32x2");
5312                writeln!(
5313                    self.out,
5314                    "metal::float2 {name}(uint b0, \
5315                                          uint b1, \
5316                                          uint b2, \
5317                                          uint b3, \
5318                                          uint b4, \
5319                                          uint b5, \
5320                                          uint b6, \
5321                                          uint b7) {{"
5322                )?;
5323                writeln!(
5324                    self.out,
5325                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5326                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5327                    back::INDENT
5328                )?;
5329                writeln!(self.out, "}}")?;
5330                Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5331            }
5332            Float32x3 => {
5333                let name = self.namer.call("unpackFloat32x3");
5334                writeln!(
5335                    self.out,
5336                    "metal::float3 {name}(uint b0, \
5337                                          uint b1, \
5338                                          uint b2, \
5339                                          uint b3, \
5340                                          uint b4, \
5341                                          uint b5, \
5342                                          uint b6, \
5343                                          uint b7, \
5344                                          uint b8, \
5345                                          uint b9, \
5346                                          uint b10, \
5347                                          uint b11) {{"
5348                )?;
5349                writeln!(
5350                    self.out,
5351                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5352                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5353                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5354                    back::INDENT
5355                )?;
5356                writeln!(self.out, "}}")?;
5357                Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5358            }
5359            Float32x4 => {
5360                let name = self.namer.call("unpackFloat32x4");
5361                writeln!(
5362                    self.out,
5363                    "metal::float4 {name}(uint b0, \
5364                                          uint b1, \
5365                                          uint b2, \
5366                                          uint b3, \
5367                                          uint b4, \
5368                                          uint b5, \
5369                                          uint b6, \
5370                                          uint b7, \
5371                                          uint b8, \
5372                                          uint b9, \
5373                                          uint b10, \
5374                                          uint b11, \
5375                                          uint b12, \
5376                                          uint b13, \
5377                                          uint b14, \
5378                                          uint b15) {{"
5379                )?;
5380                writeln!(
5381                    self.out,
5382                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5383                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5384                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5385                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5386                    back::INDENT
5387                )?;
5388                writeln!(self.out, "}}")?;
5389                Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5390            }
5391            Uint32 => {
5392                let name = self.namer.call("unpackUint32");
5393                writeln!(
5394                    self.out,
5395                    "uint {name}(uint b0, \
5396                                 uint b1, \
5397                                 uint b2, \
5398                                 uint b3) {{"
5399                )?;
5400                writeln!(
5401                    self.out,
5402                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5403                    back::INDENT
5404                )?;
5405                writeln!(self.out, "}}")?;
5406                Ok((name, 4, None, Scalar::U32))
5407            }
5408            Uint32x2 => {
5409                let name = self.namer.call("unpackUint32x2");
5410                writeln!(
5411                    self.out,
5412                    "uint2 {name}(uint b0, \
5413                                  uint b1, \
5414                                  uint b2, \
5415                                  uint b3, \
5416                                  uint b4, \
5417                                  uint b5, \
5418                                  uint b6, \
5419                                  uint b7) {{"
5420                )?;
5421                writeln!(
5422                    self.out,
5423                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5424                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5425                    back::INDENT
5426                )?;
5427                writeln!(self.out, "}}")?;
5428                Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5429            }
5430            Uint32x3 => {
5431                let name = self.namer.call("unpackUint32x3");
5432                writeln!(
5433                    self.out,
5434                    "uint3 {name}(uint b0, \
5435                                  uint b1, \
5436                                  uint b2, \
5437                                  uint b3, \
5438                                  uint b4, \
5439                                  uint b5, \
5440                                  uint b6, \
5441                                  uint b7, \
5442                                  uint b8, \
5443                                  uint b9, \
5444                                  uint b10, \
5445                                  uint b11) {{"
5446                )?;
5447                writeln!(
5448                    self.out,
5449                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5450                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5451                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5452                    back::INDENT
5453                )?;
5454                writeln!(self.out, "}}")?;
5455                Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5456            }
5457            Uint32x4 => {
5458                let name = self.namer.call("unpackUint32x4");
5459                writeln!(
5460                    self.out,
5461                    "{NAMESPACE}::uint4 {name}(uint b0, \
5462                                  uint b1, \
5463                                  uint b2, \
5464                                  uint b3, \
5465                                  uint b4, \
5466                                  uint b5, \
5467                                  uint b6, \
5468                                  uint b7, \
5469                                  uint b8, \
5470                                  uint b9, \
5471                                  uint b10, \
5472                                  uint b11, \
5473                                  uint b12, \
5474                                  uint b13, \
5475                                  uint b14, \
5476                                  uint b15) {{"
5477                )?;
5478                writeln!(
5479                    self.out,
5480                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5481                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5482                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5483                                    (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5484                    back::INDENT
5485                )?;
5486                writeln!(self.out, "}}")?;
5487                Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5488            }
5489            Sint32 => {
5490                let name = self.namer.call("unpackSint32");
5491                writeln!(
5492                    self.out,
5493                    "int {name}(uint b0, \
5494                                uint b1, \
5495                                uint b2, \
5496                                uint b3) {{"
5497                )?;
5498                writeln!(
5499                    self.out,
5500                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5501                    back::INDENT
5502                )?;
5503                writeln!(self.out, "}}")?;
5504                Ok((name, 4, None, Scalar::I32))
5505            }
5506            Sint32x2 => {
5507                let name = self.namer.call("unpackSint32x2");
5508                writeln!(
5509                    self.out,
5510                    "metal::int2 {name}(uint b0, \
5511                                        uint b1, \
5512                                        uint b2, \
5513                                        uint b3, \
5514                                        uint b4, \
5515                                        uint b5, \
5516                                        uint b6, \
5517                                        uint b7) {{"
5518                )?;
5519                writeln!(
5520                    self.out,
5521                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5522                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5523                    back::INDENT
5524                )?;
5525                writeln!(self.out, "}}")?;
5526                Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5527            }
5528            Sint32x3 => {
5529                let name = self.namer.call("unpackSint32x3");
5530                writeln!(
5531                    self.out,
5532                    "metal::int3 {name}(uint b0, \
5533                                        uint b1, \
5534                                        uint b2, \
5535                                        uint b3, \
5536                                        uint b4, \
5537                                        uint b5, \
5538                                        uint b6, \
5539                                        uint b7, \
5540                                        uint b8, \
5541                                        uint b9, \
5542                                        uint b10, \
5543                                        uint b11) {{"
5544                )?;
5545                writeln!(
5546                    self.out,
5547                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5548                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5549                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5550                    back::INDENT
5551                )?;
5552                writeln!(self.out, "}}")?;
5553                Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5554            }
5555            Sint32x4 => {
5556                let name = self.namer.call("unpackSint32x4");
5557                writeln!(
5558                    self.out,
5559                    "metal::int4 {name}(uint b0, \
5560                                        uint b1, \
5561                                        uint b2, \
5562                                        uint b3, \
5563                                        uint b4, \
5564                                        uint b5, \
5565                                        uint b6, \
5566                                        uint b7, \
5567                                        uint b8, \
5568                                        uint b9, \
5569                                        uint b10, \
5570                                        uint b11, \
5571                                        uint b12, \
5572                                        uint b13, \
5573                                        uint b14, \
5574                                        uint b15) {{"
5575                )?;
5576                writeln!(
5577                    self.out,
5578                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5579                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5580                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5581                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5582                    back::INDENT
5583                )?;
5584                writeln!(self.out, "}}")?;
5585                Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5586            }
5587            Unorm10_10_10_2 => {
5588                let name = self.namer.call("unpackUnorm10_10_10_2");
5589                writeln!(
5590                    self.out,
5591                    "metal::float4 {name}(uint b0, \
5592                                          uint b1, \
5593                                          uint b2, \
5594                                          uint b3) {{"
5595                )?;
5596                writeln!(
5597                    self.out,
5598                    // The following is correct for RGBA packing, but our format seems to
5599                    // match ABGR, which can be fed into the Metal builtin function
5600                    // unpack_unorm10a2_to_float.
5601                    /*
5602                    "{}uint v = (b3 << 24 | b2 << 16 | b1 << 8 | b0); \
5603                       uint r = (v & 0xFFC00000) >> 22; \
5604                       uint g = (v & 0x003FF000) >> 12; \
5605                       uint b = (v & 0x00000FFC) >> 2; \
5606                       uint a = (v & 0x00000003); \
5607                       return metal::float4(float(r) / 1023.0f, float(g) / 1023.0f, float(b) / 1023.0f, float(a) / 3.0f);",
5608                    */
5609                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5610                    back::INDENT
5611                )?;
5612                writeln!(self.out, "}}")?;
5613                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5614            }
5615            Unorm8x4Bgra => {
5616                let name = self.namer.call("unpackUnorm8x4Bgra");
5617                writeln!(
5618                    self.out,
5619                    "metal::float4 {name}(metal::uchar b0, \
5620                                          metal::uchar b1, \
5621                                          metal::uchar b2, \
5622                                          metal::uchar b3) {{"
5623                )?;
5624                writeln!(
5625                    self.out,
5626                    "{}return metal::float4(float(b2) / 255.0f, \
5627                                            float(b1) / 255.0f, \
5628                                            float(b0) / 255.0f, \
5629                                            float(b3) / 255.0f);",
5630                    back::INDENT
5631                )?;
5632                writeln!(self.out, "}}")?;
5633                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5634            }
5635        }
5636    }
5637
5638    fn write_wrapped_unary_op(
5639        &mut self,
5640        module: &crate::Module,
5641        func_ctx: &back::FunctionCtx,
5642        op: crate::UnaryOperator,
5643        operand: Handle<crate::Expression>,
5644    ) -> BackendResult {
5645        let operand_ty = func_ctx.resolve_type(operand, &module.types);
5646        match op {
5647            // Negating the TYPE_MIN of a two's complement signed integer
5648            // type causes overflow, which is undefined behaviour in MSL. To
5649            // avoid this we bitcast the value to unsigned and negate it,
5650            // then bitcast back to signed.
5651            // This adheres to the WGSL spec in that the negative of the
5652            // type's minimum value should equal to the minimum value.
5653            crate::UnaryOperator::Negate
5654                if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5655            {
5656                let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5657                    return Ok(());
5658                };
5659                let wrapped = WrappedFunction::UnaryOp {
5660                    op,
5661                    ty: (vector_size, scalar),
5662                };
5663                if !self.wrapped_functions.insert(wrapped) {
5664                    return Ok(());
5665                }
5666
5667                let unsigned_scalar = crate::Scalar {
5668                    kind: crate::ScalarKind::Uint,
5669                    ..scalar
5670                };
5671                let mut type_name = String::new();
5672                let mut unsigned_type_name = String::new();
5673                match vector_size {
5674                    None => {
5675                        put_numeric_type(&mut type_name, scalar, &[])?;
5676                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5677                    }
5678                    Some(size) => {
5679                        put_numeric_type(&mut type_name, scalar, &[size])?;
5680                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5681                    }
5682                };
5683
5684                writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5685                let level = back::Level(1);
5686                // For sub-32-bit types, C++ integer promotion widens
5687                // `-as_type<ushort>(val)` to `int`, so we need static_cast
5688                // to truncate back before the outer as_type bitcast.
5689                if scalar.width < 4 {
5690                    writeln!(
5691                        self.out,
5692                        "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5693                    )?;
5694                } else {
5695                    writeln!(
5696                        self.out,
5697                        "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5698                    )?;
5699                }
5700                writeln!(self.out, "}}")?;
5701                writeln!(self.out)?;
5702            }
5703            _ => {}
5704        }
5705        Ok(())
5706    }
5707
5708    fn write_wrapped_binary_op(
5709        &mut self,
5710        module: &crate::Module,
5711        func_ctx: &back::FunctionCtx,
5712        expr: Handle<crate::Expression>,
5713        op: crate::BinaryOperator,
5714        left: Handle<crate::Expression>,
5715        right: Handle<crate::Expression>,
5716    ) -> BackendResult {
5717        let expr_ty = func_ctx.resolve_type(expr, &module.types);
5718        let left_ty = func_ctx.resolve_type(left, &module.types);
5719        let right_ty = func_ctx.resolve_type(right, &module.types);
5720        match (op, expr_ty.scalar_kind()) {
5721            // Signed integer division of TYPE_MIN / -1, or signed or
5722            // unsigned division by zero, gives an unspecified value in MSL.
5723            // We override the divisor to 1 in these cases.
5724            // This adheres to the WGSL spec in that:
5725            // * TYPE_MIN / -1 == TYPE_MIN
5726            // * x / 0 == x
5727            (
5728                crate::BinaryOperator::Divide,
5729                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5730            ) => {
5731                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5732                    return Ok(());
5733                };
5734                let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5735                    return Ok(());
5736                };
5737                let wrapped = WrappedFunction::BinaryOp {
5738                    op,
5739                    left_ty: left_wrapped_ty,
5740                    right_ty: right_wrapped_ty,
5741                };
5742                if !self.wrapped_functions.insert(wrapped) {
5743                    return Ok(());
5744                }
5745
5746                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5747                    return Ok(());
5748                };
5749                let mut type_name = String::new();
5750                match vector_size {
5751                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5752                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5753                };
5754                writeln!(
5755                    self.out,
5756                    "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5757                )?;
5758                let level = back::Level(1);
5759                // Sub-32-bit types need typed literal wrappers (e.g. `short(1)`)
5760                // to avoid ambiguous metal::select overloads. For >= 32-bit,
5761                // bare literals like `1`, `-1`, `0` are unambiguous.
5762                let (lp, rp) = if scalar.width < 4 {
5763                    (format!("{type_name}("), ")".to_string())
5764                } else {
5765                    (String::new(), String::new())
5766                };
5767                match scalar.kind {
5768                    crate::ScalarKind::Sint => {
5769                        let min_val = match scalar.width {
5770                            2 => crate::Literal::I16(i16::MIN),
5771                            4 => crate::Literal::I32(i32::MIN),
5772                            8 => crate::Literal::I64(i64::MIN),
5773                            _ => {
5774                                return Err(Error::GenericValidation(format!(
5775                                    "Unexpected width for scalar {scalar:?}"
5776                                )));
5777                            }
5778                        };
5779                        write!(
5780                            self.out,
5781                            "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5782                        )?;
5783                        self.put_literal(min_val)?;
5784                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5785                    }
5786                    crate::ScalarKind::Uint => {
5787                        let suffix = if scalar.width < 4 { "" } else { "u" };
5788                        writeln!(
5789                            self.out,
5790                            "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5791                        )?
5792                    }
5793                    _ => unreachable!(),
5794                }
5795                writeln!(self.out, "}}")?;
5796                writeln!(self.out)?;
5797            }
5798            // Integer modulo where one or both operands are negative, or the
5799            // divisor is zero, is undefined behaviour in MSL. To avoid this
5800            // we use the following equation:
5801            //
5802            // dividend - (dividend / divisor) * divisor
5803            //
5804            // overriding the divisor to 1 if either it is 0, or it is -1
5805            // and the dividend is TYPE_MIN.
5806            //
5807            // This adheres to the WGSL spec in that:
5808            // * TYPE_MIN % -1 == 0
5809            // * x % 0 == 0
5810            (
5811                crate::BinaryOperator::Modulo,
5812                Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5813            ) => {
5814                let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5815                    return Ok(());
5816                };
5817                let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5818                else {
5819                    return Ok(());
5820                };
5821                let wrapped = WrappedFunction::BinaryOp {
5822                    op,
5823                    left_ty: left_wrapped_ty,
5824                    right_ty: (right_vector_size, right_scalar),
5825                };
5826                if !self.wrapped_functions.insert(wrapped) {
5827                    return Ok(());
5828                }
5829
5830                let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5831                    return Ok(());
5832                };
5833                let mut type_name = String::new();
5834                match vector_size {
5835                    None => put_numeric_type(&mut type_name, scalar, &[])?,
5836                    Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5837                };
5838                let mut rhs_type_name = String::new();
5839                match right_vector_size {
5840                    None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5841                    Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5842                };
5843
5844                writeln!(
5845                    self.out,
5846                    "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5847                )?;
5848                let level = back::Level(1);
5849                let (lp, rp) = if scalar.width < 4 {
5850                    (format!("{type_name}("), ")".to_string())
5851                } else {
5852                    (String::new(), String::new())
5853                };
5854                match scalar.kind {
5855                    crate::ScalarKind::Sint => {
5856                        let min_val = match scalar.width {
5857                            2 => crate::Literal::I16(i16::MIN),
5858                            4 => crate::Literal::I32(i32::MIN),
5859                            8 => crate::Literal::I64(i64::MIN),
5860                            _ => {
5861                                return Err(Error::GenericValidation(format!(
5862                                    "Unexpected width for scalar {scalar:?}"
5863                                )));
5864                            }
5865                        };
5866                        write!(
5867                            self.out,
5868                            "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
5869                        )?;
5870                        self.put_literal(min_val)?;
5871                        writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
5872                        writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5873                    }
5874                    crate::ScalarKind::Uint => {
5875                        let suffix = if scalar.width < 4 { "" } else { "u" };
5876                        writeln!(
5877                            self.out,
5878                            "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5879                        )?
5880                    }
5881                    _ => unreachable!(),
5882                }
5883                writeln!(self.out, "}}")?;
5884                writeln!(self.out)?;
5885            }
5886            _ => {}
5887        }
5888        Ok(())
5889    }
5890
5891    /// Build the mangled helper name for integer vector dot products.
5892    ///
5893    /// `scalar` must be a concrete integer scalar type.
5894    ///
5895    /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5896    fn get_dot_wrapper_function_helper_name(
5897        &self,
5898        scalar: crate::Scalar,
5899        size: crate::VectorSize,
5900    ) -> String {
5901        // Check for consistency with [`super::keywords::RESERVED_SET`]
5902        debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5903
5904        let type_name = scalar.to_msl_name();
5905        let size_suffix = common::vector_size_str(size);
5906        format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5907    }
5908
5909    #[allow(clippy::too_many_arguments)]
5910    fn write_wrapped_math_function(
5911        &mut self,
5912        module: &crate::Module,
5913        func_ctx: &back::FunctionCtx,
5914        fun: crate::MathFunction,
5915        arg: Handle<crate::Expression>,
5916        _arg1: Option<Handle<crate::Expression>>,
5917        _arg2: Option<Handle<crate::Expression>>,
5918        _arg3: Option<Handle<crate::Expression>>,
5919    ) -> BackendResult {
5920        let arg_ty = func_ctx.resolve_type(arg, &module.types);
5921        match fun {
5922            // Taking the absolute value of the TYPE_MIN of a two's
5923            // complement signed integer type causes overflow, which is
5924            // undefined behaviour in MSL. To avoid this, when the value is
5925            // negative we bitcast the value to unsigned and negate it, then
5926            // bitcast back to signed.
5927            // This adheres to the WGSL spec in that the absolute of the
5928            // type's minimum value should equal to the minimum value.
5929            crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5930                let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5931                    return Ok(());
5932                };
5933                let wrapped = WrappedFunction::Math {
5934                    fun,
5935                    arg_ty: (vector_size, scalar),
5936                };
5937                if !self.wrapped_functions.insert(wrapped) {
5938                    return Ok(());
5939                }
5940
5941                let unsigned_scalar = crate::Scalar {
5942                    kind: crate::ScalarKind::Uint,
5943                    ..scalar
5944                };
5945                let mut type_name = String::new();
5946                let mut unsigned_type_name = String::new();
5947                match vector_size {
5948                    None => {
5949                        put_numeric_type(&mut type_name, scalar, &[])?;
5950                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5951                    }
5952                    Some(size) => {
5953                        put_numeric_type(&mut type_name, scalar, &[size])?;
5954                        put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5955                    }
5956                };
5957
5958                writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5959                let level = back::Level(1);
5960                let zero = if scalar.width < 4 {
5961                    format!("{type_name}(0)")
5962                } else {
5963                    "0".to_string()
5964                };
5965                let neg_expr = if scalar.width < 4 {
5966                    format!(
5967                        "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
5968                    )
5969                } else {
5970                    format!("-as_type<{unsigned_type_name}>(val)")
5971                };
5972                writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
5973                writeln!(self.out, "}}")?;
5974                writeln!(self.out)?;
5975            }
5976
5977            crate::MathFunction::Dot => match *arg_ty {
5978                crate::TypeInner::Vector { size, scalar }
5979                    if matches!(
5980                        scalar.kind,
5981                        crate::ScalarKind::Sint | crate::ScalarKind::Uint
5982                    ) =>
5983                {
5984                    // De-duplicate per (fun, arg type) like other wrapped math functions
5985                    let wrapped = WrappedFunction::Math {
5986                        fun,
5987                        arg_ty: (Some(size), scalar),
5988                    };
5989                    if !self.wrapped_functions.insert(wrapped) {
5990                        return Ok(());
5991                    }
5992
5993                    let mut vec_ty = String::new();
5994                    put_numeric_type(&mut vec_ty, scalar, &[size])?;
5995                    let mut ret_ty = String::new();
5996                    put_numeric_type(&mut ret_ty, scalar, &[])?;
5997
5998                    let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5999
6000                    // Emit function signature and body using put_dot_product for the expression
6001                    writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6002                    let level = back::Level(1);
6003                    write!(self.out, "{level}return ")?;
6004                    self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6005                        write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6006                        Ok(())
6007                    })?;
6008                    writeln!(self.out, ";")?;
6009                    writeln!(self.out, "}}")?;
6010                    writeln!(self.out)?;
6011                }
6012                _ => {}
6013            },
6014
6015            _ => {}
6016        }
6017        Ok(())
6018    }
6019
6020    fn write_wrapped_cast(
6021        &mut self,
6022        module: &crate::Module,
6023        func_ctx: &back::FunctionCtx,
6024        expr: Handle<crate::Expression>,
6025        kind: crate::ScalarKind,
6026        convert: Option<crate::Bytes>,
6027    ) -> BackendResult {
6028        // Avoid undefined behaviour when casting from a float to integer
6029        // when the value is out of range for the target type. Additionally
6030        // ensure we clamp to the correct value as per the WGSL spec.
6031        //
6032        // https://www.w3.org/TR/WGSL/#floating-point-conversion:
6033        // * If X is exactly representable in the target type T, then the
6034        //   result is that value.
6035        // * Otherwise, the result is the value in T closest to
6036        //   truncate(X) and also exactly representable in the original
6037        //   floating point type.
6038        let src_ty = func_ctx.resolve_type(expr, &module.types);
6039        let Some(width) = convert else {
6040            return Ok(());
6041        };
6042        let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6043            return Ok(());
6044        };
6045        let dst_scalar = crate::Scalar { kind, width };
6046        if src_scalar.kind != crate::ScalarKind::Float
6047            || (dst_scalar.kind != crate::ScalarKind::Sint
6048                && dst_scalar.kind != crate::ScalarKind::Uint)
6049        {
6050            return Ok(());
6051        }
6052        let wrapped = WrappedFunction::Cast {
6053            src_scalar,
6054            vector_size,
6055            dst_scalar,
6056        };
6057        if !self.wrapped_functions.insert(wrapped) {
6058            return Ok(());
6059        }
6060        let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6061
6062        let mut src_type_name = String::new();
6063        match vector_size {
6064            None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6065            Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6066        };
6067        let mut dst_type_name = String::new();
6068        match vector_size {
6069            None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6070            Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6071        };
6072        let fun_name = match dst_scalar {
6073            crate::Scalar::I32 => F2I32_FUNCTION,
6074            crate::Scalar::U32 => F2U32_FUNCTION,
6075            crate::Scalar::I64 => F2I64_FUNCTION,
6076            crate::Scalar::U64 => F2U64_FUNCTION,
6077            _ => unreachable!(),
6078        };
6079
6080        writeln!(
6081            self.out,
6082            "{dst_type_name} {fun_name}({src_type_name} value) {{"
6083        )?;
6084        let level = back::Level(1);
6085        write!(
6086            self.out,
6087            "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6088        )?;
6089        self.put_literal(min)?;
6090        write!(self.out, ", ")?;
6091        self.put_literal(max)?;
6092        writeln!(self.out, "));")?;
6093        writeln!(self.out, "}}")?;
6094        writeln!(self.out)?;
6095        Ok(())
6096    }
6097
6098    /// Helper function used by [`Self::write_wrapped_image_load`] and
6099    /// [`Self::write_wrapped_image_sample`] to write the shared YUV to RGB
6100    /// conversion code for external textures. Expects the preceding code to
6101    /// declare the Y component as a `float` variable of name `y`, the UV
6102    /// components as a `float2` variable of name `uv`, and the external
6103    /// texture params as a variable of name `params`. The emitted code will
6104    /// return the result.
6105    fn write_convert_yuv_to_rgb_and_return(
6106        &mut self,
6107        level: back::Level,
6108        y: &str,
6109        uv: &str,
6110        params: &str,
6111    ) -> BackendResult {
6112        let l1 = level;
6113        let l2 = l1.next();
6114
6115        // Convert from YUV to non-linear RGB in the source color space.
6116        writeln!(
6117            self.out,
6118            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6119        )?;
6120
6121        // Apply the inverse of the source transfer function to convert to
6122        // linear RGB in the source color space.
6123        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6124        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6125        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6126        writeln!(
6127            self.out,
6128            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6129        )?;
6130
6131        // Multiply by the gamut conversion matrix to convert to linear RGB in
6132        // the destination color space.
6133        writeln!(
6134            self.out,
6135            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6136        )?;
6137
6138        // Finally, apply the dest transfer function to convert to non-linear
6139        // RGB in the destination color space, and return the result.
6140        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6141        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6142        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6143        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6144
6145        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6146        Ok(())
6147    }
6148
    /// Emit the `nagaTextureLoadExternal` helper used to load a texel from an
    /// external texture, if it has not been emitted yet.
    ///
    /// Only `ImageClass::External` images need a wrapper. For every other
    /// class, nothing is written. `image` must resolve to an image type;
    /// anything else hits `unreachable!()`. The remaining expression handles
    /// are unused.
    ///
    /// The generated function takes the external texture wrapper struct
    /// `tex` and `uint2 coords`, and performs these steps:
    /// 1. Clamp `coords` to `tex.params.size`, or to plane 0's actual size
    ///    when every component of `params.size` is zero.
    /// 2. Map the clamped coords through `tex.params.load_transform`,
    ///    rounding to texel coordinates in plane 0.
    /// 3. With one plane, return plane 0's texel directly.
    /// 4. Otherwise, read Y from plane 0 and UV from plane 1 (two planes) or
    ///    from planes 1 and 2 (three planes). Chroma coordinates are scaled by
    ///    each chroma plane's size relative to plane 0, since chroma may be
    ///    subsampled. The result goes through
    ///    [`Self::write_convert_yuv_to_rgb_and_return`].
    #[allow(clippy::too_many_arguments)]
    fn write_wrapped_image_load(
        &mut self,
        module: &crate::Module,
        func_ctx: &back::FunctionCtx,
        image: Handle<crate::Expression>,
        _coordinate: Handle<crate::Expression>,
        _array_index: Option<Handle<crate::Expression>>,
        _sample: Option<Handle<crate::Expression>>,
        _level: Option<Handle<crate::Expression>>,
    ) -> BackendResult {
        // We currently only need to wrap image loads for external textures
        let class = match *func_ctx.resolve_type(image, &module.types) {
            crate::TypeInner::Image { class, .. } => class,
            _ => unreachable!(),
        };
        if class != crate::ImageClass::External {
            return Ok(());
        }
        // Emit the helper only once per image class.
        let wrapped = WrappedFunction::ImageLoad { class };
        if !self.wrapped_functions.insert(wrapped) {
            return Ok(());
        }

        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
        let l1 = back::Level(1);
        let l2 = l1.next();
        let l3 = l2.next();
        writeln!(
            self.out,
            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
        )?;
        // Clamp coords to provided size of external texture to prevent OOB
        // read. If params.size is zero then clamp to the actual size of the
        // texture.
        writeln!(
            self.out,
            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
        )?;
        // `cropped_size - 1` is the last valid texel index in each dimension.
        writeln!(
            self.out,
            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
        )?;

        // Apply load transformation
        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
        // For single plane, simply read from plane0
        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
        writeln!(self.out, "{l1}}} else {{")?;

        // Chroma planes may be subsampled so we must scale the coords accordingly.
        writeln!(
            self.out,
            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;

        // For multi-plane, read the Y value from plane 0
        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;

        writeln!(self.out, "{l2}float2 uv;")?;
        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
        // For 2 planes, read UV from interleaved plane 1
        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
        writeln!(self.out, "{l2}}} else {{")?;
        // For 3 planes, read U and V from planes 1 and 2 respectively
        // NOTE(review): the next two lines are emitted at `l2` even though
        // they sit inside the nested `else` block, whose other statement uses
        // `l3`. This is whitespace-only in the generated MSL. Fixing it would
        // change generated output, so it is left as is.
        writeln!(
            self.out,
            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
        writeln!(
            self.out,
            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
        )?;
        writeln!(self.out, "{l2}}}")?;

        // Emits the colour conversion and the final `return` of the
        // multi-plane branch.
        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;

        writeln!(self.out, "{l1}}}")?;
        writeln!(self.out, "}}")?;
        writeln!(self.out)?;
        Ok(())
    }
6234
6235    #[allow(clippy::too_many_arguments)]
6236    fn write_wrapped_image_sample(
6237        &mut self,
6238        module: &crate::Module,
6239        func_ctx: &back::FunctionCtx,
6240        image: Handle<crate::Expression>,
6241        _sampler: Handle<crate::Expression>,
6242        _gather: Option<crate::SwizzleComponent>,
6243        _coordinate: Handle<crate::Expression>,
6244        _array_index: Option<Handle<crate::Expression>>,
6245        _offset: Option<Handle<crate::Expression>>,
6246        _level: crate::SampleLevel,
6247        _depth_ref: Option<Handle<crate::Expression>>,
6248        clamp_to_edge: bool,
6249    ) -> BackendResult {
6250        // We currently only need to wrap textureSampleBaseClampToEdge, for
6251        // both sampled and external textures.
6252        if !clamp_to_edge {
6253            return Ok(());
6254        }
6255        let class = match *func_ctx.resolve_type(image, &module.types) {
6256            crate::TypeInner::Image { class, .. } => class,
6257            _ => unreachable!(),
6258        };
6259        let wrapped = WrappedFunction::ImageSample {
6260            class,
6261            clamp_to_edge: true,
6262        };
6263        if !self.wrapped_functions.insert(wrapped) {
6264            return Ok(());
6265        }
6266        match class {
6267            crate::ImageClass::External => {
6268                writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6269                let l1 = back::Level(1);
6270                let l2 = l1.next();
6271                let l3 = l2.next();
6272                writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6273                writeln!(
6274                    self.out,
6275                    "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6276                )?;
6277
6278                // Calculate the sample bounds. The purported size of the texture
6279                // (params.size) is irrelevant here as we are dealing with normalized
6280                // coordinates. Usually we would clamp to (0,0)..(1,1). However, we must
6281                // apply the sample transformation to that, also bearing in mind that it
6282                // may contain a flip on either axis. We calculate and adjust for the
6283                // half-texel separately for each plane as it depends on the actual
6284                // texture size which may vary between planes.
6285                writeln!(
6286                    self.out,
6287                    "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6288                )?;
6289                writeln!(
6290                    self.out,
6291                    "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6292                )?;
6293                writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6294                writeln!(
6295                    self.out,
6296                    "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6297                )?;
6298                writeln!(
6299                    self.out,
6300                    "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6301                )?;
6302                writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6303                // For single plane, simply sample from plane0
6304                writeln!(
6305                    self.out,
6306                    "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6307                )?;
6308                writeln!(self.out, "{l1}}} else {{")?;
6309                writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6310                writeln!(
6311                    self.out,
6312                    "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6313                )?;
6314                writeln!(
6315                    self.out,
6316                    "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6317                )?;
6318
6319                // For multi-plane, sample the Y value from plane 0
6320                writeln!(
6321                    self.out,
6322                    "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6323                )?;
6324                writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6325                writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6326                // For 2 planes, sample UV from interleaved plane 1
6327                writeln!(
6328                    self.out,
6329                    "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6330                )?;
6331                writeln!(self.out, "{l2}}} else {{")?;
6332                // For 3 planes, sample U and V from planes 1 and 2 respectively
6333                writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6334                writeln!(
6335                    self.out,
6336                    "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6337                )?;
6338                writeln!(
6339                    self.out,
6340                    "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6341                )?;
6342                writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6343                writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6344                writeln!(self.out, "{l2}}}")?;
6345
6346                self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6347
6348                writeln!(self.out, "{l1}}}")?;
6349                writeln!(self.out, "}}")?;
6350                writeln!(self.out)?;
6351            }
6352            _ => {
6353                writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6354                let l1 = back::Level(1);
6355                writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6356                writeln!(
6357                    self.out,
6358                    "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6359                )?;
6360                writeln!(self.out, "}}")?;
6361                writeln!(self.out)?;
6362            }
6363        }
6364        Ok(())
6365    }
6366
6367    fn write_wrapped_image_query(
6368        &mut self,
6369        module: &crate::Module,
6370        func_ctx: &back::FunctionCtx,
6371        image: Handle<crate::Expression>,
6372        query: crate::ImageQuery,
6373    ) -> BackendResult {
6374        // We currently only need to wrap size image queries for external textures
6375        if !matches!(query, crate::ImageQuery::Size { .. }) {
6376            return Ok(());
6377        }
6378        let class = match *func_ctx.resolve_type(image, &module.types) {
6379            crate::TypeInner::Image { class, .. } => class,
6380            _ => unreachable!(),
6381        };
6382        if class != crate::ImageClass::External {
6383            return Ok(());
6384        }
6385        let wrapped = WrappedFunction::ImageQuerySize { class };
6386        if !self.wrapped_functions.insert(wrapped) {
6387            return Ok(());
6388        }
6389        writeln!(
6390            self.out,
6391            "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6392        )?;
6393        let l1 = back::Level(1);
6394        let l2 = l1.next();
6395        writeln!(
6396            self.out,
6397            "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6398        )?;
6399        writeln!(self.out, "{l2}return tex.params.size;")?;
6400        writeln!(self.out, "{l1}}} else {{")?;
6401        // params.size == (0, 0) indicates to query and return plane 0's actual size
6402        writeln!(
6403            self.out,
6404            "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6405        )?;
6406        writeln!(self.out, "{l1}}}")?;
6407        writeln!(self.out, "}}")?;
6408        writeln!(self.out)?;
6409        Ok(())
6410    }
6411
6412    fn write_wrapped_cooperative_load(
6413        &mut self,
6414        module: &crate::Module,
6415        func_ctx: &back::FunctionCtx,
6416        columns: crate::CooperativeSize,
6417        rows: crate::CooperativeSize,
6418        pointer: Handle<crate::Expression>,
6419    ) -> BackendResult {
6420        let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6421        let space = ptr_ty.pointer_space().unwrap();
6422        let space_name = space.to_msl_name().unwrap_or_default();
6423        let scalar = ptr_ty
6424            .pointer_base_type()
6425            .unwrap()
6426            .inner_with(&module.types)
6427            .scalar()
6428            .unwrap();
6429        let wrapped = WrappedFunction::CooperativeLoad {
6430            space_name,
6431            columns,
6432            rows,
6433            scalar,
6434        };
6435        if !self.wrapped_functions.insert(wrapped) {
6436            return Ok(());
6437        }
6438        let scalar_name = scalar.to_msl_name();
6439        writeln!(
6440            self.out,
6441            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6442            columns as u32, rows as u32,
6443        )?;
6444        let l1 = back::Level(1);
6445        writeln!(
6446            self.out,
6447            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6448            columns as u32, rows as u32
6449        )?;
6450        let matrix_origin = "0";
6451        writeln!(
6452            self.out,
6453            "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6454        )?;
6455        writeln!(self.out, "{l1}return m;")?;
6456        writeln!(self.out, "}}")?;
6457        writeln!(self.out)?;
6458        Ok(())
6459    }
6460
6461    fn write_wrapped_cooperative_multiply_add(
6462        &mut self,
6463        module: &crate::Module,
6464        func_ctx: &back::FunctionCtx,
6465        space: crate::AddressSpace,
6466        a: Handle<crate::Expression>,
6467        b: Handle<crate::Expression>,
6468    ) -> BackendResult {
6469        let space_name = space.to_msl_name().unwrap_or_default();
6470        let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6471            crate::TypeInner::CooperativeMatrix {
6472                columns,
6473                rows,
6474                scalar,
6475                ..
6476            } => (columns, rows, scalar),
6477            _ => unreachable!(),
6478        };
6479        let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6480            crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6481            _ => unreachable!(),
6482        };
6483        let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6484            space_name,
6485            columns: b_c,
6486            rows: a_r,
6487            intermediate: a_c,
6488            scalar,
6489        };
6490        if !self.wrapped_functions.insert(wrapped) {
6491            return Ok(());
6492        }
6493        let scalar_name = scalar.to_msl_name();
6494        writeln!(
6495            self.out,
6496            "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6497            b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6498        )?;
6499        let l1 = back::Level(1);
6500        writeln!(
6501            self.out,
6502            "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6503            b_c as u32, a_r as u32
6504        )?;
6505        writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6506        writeln!(self.out, "{l1}return d;")?;
6507        writeln!(self.out, "}}")?;
6508        writeln!(self.out)?;
6509        Ok(())
6510    }
6511
6512    pub(super) fn write_wrapped_functions(
6513        &mut self,
6514        module: &crate::Module,
6515        func_ctx: &back::FunctionCtx,
6516    ) -> BackendResult {
6517        for (expr_handle, expr) in func_ctx.expressions.iter() {
6518            match *expr {
6519                crate::Expression::Unary { op, expr: operand } => {
6520                    self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6521                }
6522                crate::Expression::Binary { op, left, right } => {
6523                    self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6524                }
6525                crate::Expression::Math {
6526                    fun,
6527                    arg,
6528                    arg1,
6529                    arg2,
6530                    arg3,
6531                } => {
6532                    self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6533                }
6534                crate::Expression::As {
6535                    expr,
6536                    kind,
6537                    convert,
6538                } => {
6539                    self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6540                }
6541                crate::Expression::ImageLoad {
6542                    image,
6543                    coordinate,
6544                    array_index,
6545                    sample,
6546                    level,
6547                } => {
6548                    self.write_wrapped_image_load(
6549                        module,
6550                        func_ctx,
6551                        image,
6552                        coordinate,
6553                        array_index,
6554                        sample,
6555                        level,
6556                    )?;
6557                }
6558                crate::Expression::ImageSample {
6559                    image,
6560                    sampler,
6561                    gather,
6562                    coordinate,
6563                    array_index,
6564                    offset,
6565                    level,
6566                    depth_ref,
6567                    clamp_to_edge,
6568                } => {
6569                    self.write_wrapped_image_sample(
6570                        module,
6571                        func_ctx,
6572                        image,
6573                        sampler,
6574                        gather,
6575                        coordinate,
6576                        array_index,
6577                        offset,
6578                        level,
6579                        depth_ref,
6580                        clamp_to_edge,
6581                    )?;
6582                }
6583                crate::Expression::ImageQuery { image, query } => {
6584                    self.write_wrapped_image_query(module, func_ctx, image, query)?;
6585                }
6586                crate::Expression::CooperativeLoad {
6587                    columns,
6588                    rows,
6589                    role: _,
6590                    ref data,
6591                } => {
6592                    self.write_wrapped_cooperative_load(
6593                        module,
6594                        func_ctx,
6595                        columns,
6596                        rows,
6597                        data.pointer,
6598                    )?;
6599                }
6600                crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6601                    let space = crate::AddressSpace::Private;
6602                    self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6603                }
6604                crate::Expression::RayQueryGetIntersection { committed, .. } => {
6605                    self.write_rq_get_intersection_function(module, committed)?;
6606                }
6607                _ => {}
6608            }
6609        }
6610
6611        Ok(())
6612    }
6613
6614    // Returns the array of mapped entry point names.
6615    fn write_functions(
6616        &mut self,
6617        module: &crate::Module,
6618        mod_info: &valid::ModuleInfo,
6619        options: &Options,
6620        pipeline_options: &PipelineOptions,
6621    ) -> Result<TranslationInfo, Error> {
6622        use back::msl::VertexFormat;
6623
6624        // Define structs to hold resolved/generated data for vertex buffers and
6625        // their attributes.
6626        struct AttributeMappingResolved {
6627            ty_name: String,
6628            dimension: Option<crate::VectorSize>,
6629            scalar: crate::Scalar,
6630            name: String,
6631        }
6632        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6633
6634        struct VertexBufferMappingResolved<'a> {
6635            id: u32,
6636            stride: u32,
6637            step_mode: back::msl::VertexBufferStepMode,
6638            ty_name: String,
6639            param_name: String,
6640            elem_name: String,
6641            attributes: &'a Vec<back::msl::AttributeMapping>,
6642        }
6643        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6644
6645        // Define a struct to hold a named reference to a byte-unpacking function.
6646        struct UnpackingFunction {
6647            name: String,
6648            byte_count: u32,
6649            dimension: Option<crate::VectorSize>,
6650            scalar: crate::Scalar,
6651        }
6652        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6653
6654        // Check if we are attempting vertex pulling. If we are, generate some
6655        // names we'll need, and iterate the vertex buffer mappings to output
6656        // all the conversion functions we'll need to unpack the attribute data.
6657        // We can re-use these names for all entry points that need them, since
6658        // those entry points also use self.namer.
6659        let mut needs_vertex_id = false;
6660        let v_id = self.namer.call("v_id");
6661
6662        let mut needs_instance_id = false;
6663        let i_id = self.namer.call("i_id");
6664        if pipeline_options.vertex_pulling_transform {
6665            for vbm in &pipeline_options.vertex_buffer_mappings {
6666                let buffer_id = vbm.id;
6667                let buffer_stride = vbm.stride;
6668
6669                assert!(
6670                    buffer_stride > 0,
6671                    "Vertex pulling requires a non-zero buffer stride."
6672                );
6673
6674                match vbm.step_mode {
6675                    back::msl::VertexBufferStepMode::Constant => {}
6676                    back::msl::VertexBufferStepMode::ByVertex => {
6677                        needs_vertex_id = true;
6678                    }
6679                    back::msl::VertexBufferStepMode::ByInstance => {
6680                        needs_instance_id = true;
6681                    }
6682                }
6683
6684                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6685                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6686                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6687
6688                vbm_resolved.push(VertexBufferMappingResolved {
6689                    id: buffer_id,
6690                    stride: buffer_stride,
6691                    step_mode: vbm.step_mode,
6692                    ty_name: buffer_ty,
6693                    param_name: buffer_param,
6694                    elem_name: buffer_elem,
6695                    attributes: &vbm.attributes,
6696                });
6697
6698                // Iterate the attributes and generate needed unpacking functions.
6699                for attribute in &vbm.attributes {
6700                    if unpacking_functions.contains_key(&attribute.format) {
6701                        continue;
6702                    }
6703                    let (name, byte_count, dimension, scalar) =
6704                        match self.write_unpacking_function(attribute.format) {
6705                            Ok((name, byte_count, dimension, scalar)) => {
6706                                (name, byte_count, dimension, scalar)
6707                            }
6708                            _ => {
6709                                continue;
6710                            }
6711                        };
6712                    unpacking_functions.insert(
6713                        attribute.format,
6714                        UnpackingFunction {
6715                            name,
6716                            byte_count,
6717                            dimension,
6718                            scalar,
6719                        },
6720                    );
6721                }
6722            }
6723        }
6724
6725        let mut pass_through_globals = Vec::new();
6726        for (fun_handle, fun) in module.functions.iter() {
6727            log::trace!(
6728                "function {:?}, handle {:?}",
6729                fun.name.as_deref().unwrap_or("(anonymous)"),
6730                fun_handle
6731            );
6732
6733            let ctx = back::FunctionCtx {
6734                ty: back::FunctionType::Function(fun_handle),
6735                info: &mod_info[fun_handle],
6736                expressions: &fun.expressions,
6737                named_expressions: &fun.named_expressions,
6738            };
6739
6740            writeln!(self.out)?;
6741            self.write_wrapped_functions(module, &ctx)?;
6742
6743            let fun_info = &mod_info[fun_handle];
6744            pass_through_globals.clear();
6745            let mut needs_buffer_sizes = false;
6746            for (handle, var) in module.global_variables.iter() {
6747                if !fun_info[handle].is_empty() {
6748                    if var.space.needs_pass_through() {
6749                        pass_through_globals.push(handle);
6750                    }
6751                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6752                }
6753            }
6754
6755            let fun_name = &self.names[&NameKey::Function(fun_handle)];
6756            match fun.result {
6757                Some(ref result) => {
6758                    let ty_name = TypeContext {
6759                        handle: result.ty,
6760                        gctx: module.to_ctx(),
6761                        names: &self.names,
6762                        access: crate::StorageAccess::empty(),
6763                        first_time: false,
6764                    };
6765                    write!(self.out, "{ty_name}")?;
6766                }
6767                None => {
6768                    write!(self.out, "void")?;
6769                }
6770            }
6771            writeln!(self.out, " {fun_name}(")?;
6772
6773            for (index, arg) in fun.arguments.iter().enumerate() {
6774                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6775                let param_type_name = TypeContext {
6776                    handle: arg.ty,
6777                    gctx: module.to_ctx(),
6778                    names: &self.names,
6779                    access: crate::StorageAccess::empty(),
6780                    first_time: false,
6781                };
6782                let separator = separate(
6783                    !pass_through_globals.is_empty()
6784                        || index + 1 != fun.arguments.len()
6785                        || needs_buffer_sizes,
6786                );
6787                writeln!(
6788                    self.out,
6789                    "{}{} {}{}",
6790                    back::INDENT,
6791                    param_type_name,
6792                    name,
6793                    separator
6794                )?;
6795            }
6796            for (index, &handle) in pass_through_globals.iter().enumerate() {
6797                let tyvar = TypedGlobalVariable {
6798                    module,
6799                    names: &self.names,
6800                    handle,
6801                    usage: fun_info[handle],
6802                    reference: true,
6803                };
6804                let separator =
6805                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6806                write!(self.out, "{}", back::INDENT)?;
6807                tyvar.try_fmt(&mut self.out)?;
6808                writeln!(self.out, "{separator}")?;
6809            }
6810
6811            if needs_buffer_sizes {
6812                writeln!(
6813                    self.out,
6814                    "{}constant _mslBufferSizes& _buffer_sizes",
6815                    back::INDENT
6816                )?;
6817            }
6818
6819            writeln!(self.out, ") {{")?;
6820
6821            let guarded_indices =
6822                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6823
6824            let context = StatementContext {
6825                expression: ExpressionContext {
6826                    function: fun,
6827                    origin: FunctionOrigin::Handle(fun_handle),
6828                    info: fun_info,
6829                    lang_version: options.lang_version,
6830                    policies: options.bounds_check_policies,
6831                    guarded_indices,
6832                    module,
6833                    mod_info,
6834                    pipeline_options,
6835                    force_loop_bounding: options.force_loop_bounding,
6836                },
6837                result_struct: None,
6838            };
6839
6840            self.put_locals(&context.expression)?;
6841            self.update_expressions_to_bake(fun, fun_info, &context.expression);
6842            self.put_block(back::Level(1), &fun.body, &context)?;
6843            writeln!(self.out, "}}")?;
6844            self.named_expressions.clear();
6845        }
6846
6847        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6848            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6849
6850        let mut info = TranslationInfo {
6851            entry_point_names: Vec::with_capacity(ep_range.len()),
6852        };
6853
6854        for ep_index in ep_range {
6855            let ep = &module.entry_points[ep_index];
6856            let fun = &ep.function;
6857            let fun_info = mod_info.get_entry_point(ep_index);
6858            let mut ep_error = None;
6859
6860            // For vertex_id and instance_id arguments, presume that we'll
6861            // use our generated names, but switch to the name of an
6862            // existing @builtin param, if we find one.
6863            let mut v_existing_id = None;
6864            let mut i_existing_id = None;
6865
6866            log::trace!(
6867                "entry point {:?}, index {:?}",
6868                fun.name.as_deref().unwrap_or("(anonymous)"),
6869                ep_index
6870            );
6871
6872            let ctx = back::FunctionCtx {
6873                ty: back::FunctionType::EntryPoint(ep_index as u16),
6874                info: fun_info,
6875                expressions: &fun.expressions,
6876                named_expressions: &fun.named_expressions,
6877            };
6878
6879            self.write_wrapped_functions(module, &ctx)?;
6880
6881            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6882                crate::ShaderStage::Vertex => (
6883                    Some("vertex"),
6884                    LocationMode::VertexInput,
6885                    LocationMode::VertexOutput,
6886                    true,
6887                ),
6888                crate::ShaderStage::Fragment => (
6889                    Some("fragment"),
6890                    LocationMode::FragmentInput,
6891                    LocationMode::FragmentOutput,
6892                    false,
6893                ),
6894                crate::ShaderStage::Compute => (
6895                    Some("kernel"),
6896                    LocationMode::Uniform,
6897                    LocationMode::Uniform,
6898                    false,
6899                ),
6900                crate::ShaderStage::Task => {
6901                    (None, LocationMode::Uniform, LocationMode::Uniform, false)
6902                }
6903                crate::ShaderStage::Mesh => {
6904                    (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6905                }
6906                crate::ShaderStage::RayGeneration
6907                | crate::ShaderStage::AnyHit
6908                | crate::ShaderStage::ClosestHit
6909                | crate::ShaderStage::Miss => unimplemented!(),
6910            };
6911
6912            // Should this entry point be modified to do vertex pulling?
6913            let do_vertex_pulling = can_vertex_pull
6914                && pipeline_options.vertex_pulling_transform
6915                && !pipeline_options.vertex_buffer_mappings.is_empty();
6916
6917            // Is any global variable used by this entry point dynamically sized?
6918            let needs_buffer_sizes = do_vertex_pulling
6919                || module
6920                    .global_variables
6921                    .iter()
6922                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
6923                    .any(|(_, var)| needs_array_length(var.ty, &module.types));
6924
6925            // skip this entry point if any global bindings are missing,
6926            // or their types are incompatible.
6927            if !options.fake_missing_bindings {
6928                for (var_handle, var) in module.global_variables.iter() {
6929                    if fun_info[var_handle].is_empty() {
6930                        continue;
6931                    }
6932                    match var.space {
6933                        crate::AddressSpace::Uniform
6934                        | crate::AddressSpace::Storage { .. }
6935                        | crate::AddressSpace::Handle => {
6936                            let br = match var.binding {
6937                                Some(ref br) => br,
6938                                None => {
6939                                    let var_name = var.name.clone().unwrap_or_default();
6940                                    ep_error =
6941                                        Some(super::EntryPointError::MissingBinding(var_name));
6942                                    break;
6943                                }
6944                            };
6945                            let target = options.get_resource_binding_target(ep, br);
6946                            let good = match target {
6947                                Some(target) => {
6948                                    // We intentionally don't dereference binding_arrays here,
6949                                    // so that binding arrays fall to the buffer location.
6950
6951                                    match module.types[var.ty].inner {
6952                                        crate::TypeInner::Image {
6953                                            class: crate::ImageClass::External,
6954                                            ..
6955                                        } => target.external_texture.is_some(),
6956                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
6957                                        crate::TypeInner::Sampler { .. } => {
6958                                            target.sampler.is_some()
6959                                        }
6960                                        _ => target.buffer.is_some(),
6961                                    }
6962                                }
6963                                None => false,
6964                            };
6965                            if !good {
6966                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6967                                break;
6968                            }
6969                        }
6970                        crate::AddressSpace::Immediate => {
6971                            if let Err(e) = options.resolve_immediates(ep) {
6972                                ep_error = Some(e);
6973                                break;
6974                            }
6975                        }
6976                        crate::AddressSpace::Function
6977                        | crate::AddressSpace::Private
6978                        | crate::AddressSpace::WorkGroup
6979                        | crate::AddressSpace::TaskPayload => {}
6980                        crate::AddressSpace::RayPayload
6981                        | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6982                    }
6983                }
6984                if needs_buffer_sizes {
6985                    if let Err(err) = options.resolve_sizes_buffer(ep) {
6986                        ep_error = Some(err);
6987                    }
6988                }
6989            }
6990
6991            if let Some(err) = ep_error {
6992                info.entry_point_names.push(Err(err));
6993                continue;
6994            }
6995            let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
6996            info.entry_point_names.push(Ok(fun_name.clone()));
6997
6998            writeln!(self.out)?;
6999
7000            // Since `Namer.reset` wasn't expecting struct members to be
7001            // suddenly injected into another namespace like this,
7002            // `self.names` doesn't keep them distinct from other variables.
7003            // Generate fresh names for these arguments, and remember the
7004            // mapping.
7005            let mut flattened_member_names = FastHashMap::default();
7006            // Varyings' members get their own namespace
7007            let mut varyings_namer = proc::Namer::default();
7008
7009            let mut empty_names = FastHashMap::default(); // Create a throwaway map
7010            varyings_namer.reset(
7011                module,
7012                &super::keywords::RESERVED_SET,
7013                proc::KeywordSet::empty(),
7014                proc::CaseInsensitiveKeywordSet::empty(),
7015                &[CLAMPED_LOD_LOAD_PREFIX],
7016                &mut empty_names,
7017            );
7018
7019            // List all the Naga `EntryPoint`'s `Function`'s arguments,
7020            // flattening structs into their members. In Metal, we will pass
7021            // each of these values to the entry point as a separate argument—
7022            // except for the varyings, handled next.
7023            let mut flattened_arguments = Vec::new();
7024            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7025                match module.types[arg.ty].inner {
7026                    crate::TypeInner::Struct { ref members, .. } => {
7027                        for (member_index, member) in members.iter().enumerate() {
7028                            let member_index = member_index as u32;
7029                            flattened_arguments.push((
7030                                NameKey::StructMember(arg.ty, member_index),
7031                                member.ty,
7032                                member.binding.as_ref(),
7033                            ));
7034                            let name_key = NameKey::StructMember(arg.ty, member_index);
7035                            let name = match member.binding {
7036                                Some(crate::Binding::Location { .. }) => {
7037                                    if do_vertex_pulling {
7038                                        self.namer.call(&self.names[&name_key])
7039                                    } else {
7040                                        varyings_namer.call(&self.names[&name_key])
7041                                    }
7042                                }
7043                                _ => self.namer.call(&self.names[&name_key]),
7044                            };
7045                            flattened_member_names.insert(name_key, name);
7046                        }
7047                    }
7048                    _ => flattened_arguments.push((
7049                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7050                        arg.ty,
7051                        arg.binding.as_ref(),
7052                    )),
7053                }
7054            }
7055
7056            // Identify the varyings among the argument values, and maybe emit
7057            // a struct type named `<fun>Input` to hold them. If we are doing
7058            // vertex pulling, we instead update our attribute mapping to
7059            // note the types, names, and zero values of the attributes.
7060            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7061            let varyings_member_name = self.namer.call("varyings");
7062            let mut has_varyings = false;
7063
7064            if !flattened_arguments.is_empty() {
7065                if !do_vertex_pulling {
7066                    writeln!(self.out, "struct {stage_in_name} {{")?;
7067                }
7068                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7069                    let Some(binding) = binding else {
7070                        continue;
7071                    };
7072                    let name = match *name_key {
7073                        NameKey::StructMember(..) => &flattened_member_names[name_key],
7074                        _ => &self.names[name_key],
7075                    };
7076                    let ty_name = TypeContext {
7077                        handle: ty,
7078                        gctx: module.to_ctx(),
7079                        names: &self.names,
7080                        access: crate::StorageAccess::empty(),
7081                        first_time: false,
7082                    };
7083                    let resolved = options.resolve_local_binding(binding, in_mode)?;
7084                    let location = match *binding {
7085                        crate::Binding::Location { location, .. } => Some(location),
7086                        crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7087                        crate::Binding::BuiltIn(_) => continue,
7088                    };
7089                    if do_vertex_pulling {
7090                        let Some(location) = location else {
7091                            continue;
7092                        };
7093                        // Update our attribute mapping.
7094                        am_resolved.insert(
7095                            location,
7096                            AttributeMappingResolved {
7097                                ty_name: ty_name.to_string(),
7098                                dimension: ty_name.vector_size(),
7099                                scalar: ty_name.scalar().unwrap(),
7100                                name: name.to_string(),
7101                            },
7102                        );
7103                    } else {
7104                        has_varyings = true;
7105                        if let super::ResolvedBinding::User {
7106                            prefix,
7107                            index,
7108                            interpolation: Some(super::ResolvedInterpolation::PerVertex),
7109                        } = resolved
7110                        {
7111                            if options.lang_version < (4, 0) {
7112                                return Err(Error::PerVertexNotSupported);
7113                            }
7114                            write!(
7115                                self.out,
7116                                "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7117                                back::INDENT,
7118                                ty_name.unwrap_array()
7119                            )?;
7120                        } else {
7121                            write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7122                            resolved.try_fmt(&mut self.out)?;
7123                        }
7124                        writeln!(self.out, ";")?;
7125                    }
7126                }
7127                if !do_vertex_pulling {
7128                    writeln!(self.out, "}};")?;
7129                }
7130            }
7131
7132            // Define a struct type named for the return value, if any, named
7133            // `<fun>Output`.
7134            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7135            let result_member_name = self.namer.call("member");
7136            let result_type_name = match fun.result {
7137                Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7138                    let mut result_members = Vec::new();
7139                    if let crate::TypeInner::Struct { ref members, .. } =
7140                        module.types[result.ty].inner
7141                    {
7142                        for (member_index, member) in members.iter().enumerate() {
7143                            result_members.push((
7144                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7145                                member.ty,
7146                                member.binding.as_ref(),
7147                            ));
7148                        }
7149                    } else {
7150                        result_members.push((
7151                            &result_member_name,
7152                            result.ty,
7153                            result.binding.as_ref(),
7154                        ));
7155                    }
7156
7157                    writeln!(self.out, "struct {stage_out_name} {{")?;
7158                    let mut has_point_size = false;
7159                    for (name, ty, binding) in result_members {
7160                        let ty_name = TypeContext {
7161                            handle: ty,
7162                            gctx: module.to_ctx(),
7163                            names: &self.names,
7164                            access: crate::StorageAccess::empty(),
7165                            first_time: true,
7166                        };
7167                        let binding = binding.ok_or_else(|| {
7168                            Error::GenericValidation("Expected binding, got None".into())
7169                        })?;
7170
7171                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7172                            has_point_size = true;
7173                            if !pipeline_options.allow_and_force_point_size {
7174                                continue;
7175                            }
7176                        }
7177
7178                        let array_len = match module.types[ty].inner {
7179                            crate::TypeInner::Array {
7180                                size: crate::ArraySize::Constant(size),
7181                                ..
7182                            } => Some(size),
7183                            _ => None,
7184                        };
7185                        let resolved = options.resolve_local_binding(binding, out_mode)?;
7186                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7187                        resolved.try_fmt(&mut self.out)?;
7188                        if let Some(array_len) = array_len {
7189                            write!(self.out, " [{array_len}]")?;
7190                        }
7191                        writeln!(self.out, ";")?;
7192                    }
7193
7194                    if pipeline_options.allow_and_force_point_size
7195                        && ep.stage == crate::ShaderStage::Vertex
7196                        && !has_point_size
7197                    {
7198                        // inject the point size output last
7199                        writeln!(
7200                            self.out,
7201                            "{}float _point_size [[point_size]];",
7202                            back::INDENT
7203                        )?;
7204                    }
7205                    writeln!(self.out, "}};")?;
7206                    &stage_out_name
7207                }
7208                Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7209                    assert_eq!(
7210                        module.types[result.ty].inner,
7211                        crate::TypeInner::Vector {
7212                            size: crate::VectorSize::Tri,
7213                            scalar: crate::Scalar::U32
7214                        }
7215                    );
7216
7217                    "metal::uint3"
7218                }
7219                _ => "void",
7220            };
7221
7222            let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7223                Some(self.write_mesh_output_types(
7224                    mesh_info,
7225                    &fun_name,
7226                    module,
7227                    pipeline_options.allow_and_force_point_size,
7228                    options,
7229                )?)
7230            } else {
7231                None
7232            };
7233
7234            // If we're doing a vertex pulling transform, define the buffer
7235            // structure types.
7236            if do_vertex_pulling {
7237                for vbm in &vbm_resolved {
7238                    let buffer_stride = vbm.stride;
7239                    let buffer_ty = &vbm.ty_name;
7240
7241                    // Define a structure of bytes of the appropriate size.
7242                    // When we access the attributes, we'll be unpacking these
7243                    // bytes at some offset.
7244                    writeln!(
7245                        self.out,
7246                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7247                    )?;
7248                }
7249            }
7250
7251            let is_wrapped = matches!(
7252                ep.stage,
7253                crate::ShaderStage::Task | crate::ShaderStage::Mesh
7254            );
7255            let fun_name = fun_name.clone();
7256            let nested_fun_name = if is_wrapped {
7257                self.namer.call(&format!("_{fun_name}"))
7258            } else {
7259                fun_name.clone()
7260            };
7261
7262            // https://web.archive.org/web/20181029003926/https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
7263            if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7264                let total_threads =
7265                    ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7266                write!(
7267                    self.out,
7268                    "[[max_total_threads_per_threadgroup({total_threads})]] "
7269                )?;
7270            }
7271
7272            // Write the entry point function's name, and begin its argument list.
7273            if let Some(em_str) = em_str {
7274                write!(self.out, "{em_str} ")?;
7275            }
7276            writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7277
7278            let mut args = Vec::new();
7279
7280            // If we have produced a struct holding the `EntryPoint`'s
7281            // `Function`'s arguments' varyings, pass that struct first.
7282            if has_varyings {
7283                args.push(EntryPointArgument {
7284                    ty_name: stage_in_name,
7285                    name: varyings_member_name.clone(),
7286                    binding: " [[stage_in]]".to_string(),
7287                    init: None,
7288                });
7289            }
7290
7291            let mut local_invocation_index = None;
7292
7293            // Then pass the remaining arguments not included in the varyings
7294            // struct.
7295            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7296                let binding = match binding {
7297                    Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7298                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7299                    _ => continue,
7300                };
7301                let name = match *name_key {
7302                    NameKey::StructMember(..) => &flattened_member_names[name_key],
7303                    _ => &self.names[name_key],
7304                };
7305
7306                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7307                    local_invocation_index = Some(name_key);
7308                }
7309
7310                let ty_name = TypeContext {
7311                    handle: ty,
7312                    gctx: module.to_ctx(),
7313                    names: &self.names,
7314                    access: crate::StorageAccess::empty(),
7315                    first_time: false,
7316                };
7317
7318                match *binding {
7319                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7320                        v_existing_id = Some(name.clone());
7321                    }
7322                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7323                        i_existing_id = Some(name.clone());
7324                    }
7325                    _ => {}
7326                };
7327
7328                let resolved = options.resolve_local_binding(binding, in_mode)?;
7329                let mut binding = String::new();
7330                resolved.try_fmt(&mut binding)?;
7331
7332                args.push(EntryPointArgument {
7333                    ty_name: format!("{ty_name}"),
7334                    name: name.clone(),
7335                    binding,
7336                    init: None,
7337                });
7338            }
7339
7340            let need_workgroup_variables_initialization =
7341                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7342
7343            if local_invocation_index.is_none()
7344                && (need_workgroup_variables_initialization
7345                    || ep.stage == crate::ShaderStage::Task
7346                    || ep.stage == crate::ShaderStage::Mesh)
7347            {
7348                args.push(EntryPointArgument {
7349                    ty_name: "uint".to_string(),
7350                    name: "__local_invocation_index".to_string(),
7351                    binding: " [[thread_index_in_threadgroup]]".to_string(),
7352                    init: None,
7353                });
7354            }
7355
7356            // Those global variables used by this entry point and its callees
7357            // get passed as arguments. `Private` globals are an exception, they
7358            // don't outlive this invocation, so we declare them below as locals
7359            // within the entry point.
7360            for (handle, var) in module.global_variables.iter() {
7361                let usage = fun_info[handle];
7362                if usage.is_empty() || var.space == crate::AddressSpace::Private {
7363                    continue;
7364                }
7365
7366                if options.lang_version < (1, 2) {
7367                    match var.space {
7368                        // This restriction is not documented in the MSL spec
7369                        // but validation will fail if it is not upheld.
7370                        //
7371                        // We infer the required version from the "Function
7372                        // Buffer Read-Writes" section of [what's new], where
7373                        // the feature sets listed correspond with the ones
7374                        // supporting MSL 1.2.
7375                        //
7376                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7377                        crate::AddressSpace::Storage { access }
7378                            if access.contains(crate::StorageAccess::STORE)
7379                                && ep.stage == crate::ShaderStage::Fragment =>
7380                        {
7381                            return Err(Error::UnsupportedWritableStorageBuffer)
7382                        }
7383                        crate::AddressSpace::Handle => {
7384                            match module.types[var.ty].inner {
7385                                crate::TypeInner::Image {
7386                                    class: crate::ImageClass::Storage { access, .. },
7387                                    ..
7388                                } => {
7389                                    // This restriction is not documented in the MSL spec
7390                                    // but validation will fail if it is not upheld.
7391                                    //
7392                                    // We infer the required version from the "Function
7393                                    // Texture Read-Writes" section of [what's new], where
7394                                    // the feature sets listed correspond with the ones
7395                                    // supporting MSL 1.2.
7396                                    //
7397                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
7398                                    if access.contains(crate::StorageAccess::STORE)
7399                                        && (ep.stage == crate::ShaderStage::Vertex
7400                                            || ep.stage == crate::ShaderStage::Fragment)
7401                                    {
7402                                        return Err(Error::UnsupportedWritableStorageTexture(
7403                                            ep.stage,
7404                                        ));
7405                                    }
7406
7407                                    if access.contains(
7408                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7409                                    ) {
7410                                        return Err(Error::UnsupportedRWStorageTexture);
7411                                    }
7412                                }
7413                                _ => {}
7414                            }
7415                        }
7416                        _ => {}
7417                    }
7418                }
7419
7420                // Check min MSL version for binding arrays
7421                match var.space {
7422                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
7423                        crate::TypeInner::BindingArray { base, .. } => {
7424                            match module.types[base].inner {
7425                                crate::TypeInner::Sampler { .. } => {
7426                                    if options.lang_version < (2, 0) {
7427                                        return Err(Error::UnsupportedArrayOf(
7428                                            "samplers".to_string(),
7429                                        ));
7430                                    }
7431                                }
7432                                crate::TypeInner::Image { class, .. } => match class {
7433                                    crate::ImageClass::Sampled { .. }
7434                                    | crate::ImageClass::Depth { .. }
7435                                    | crate::ImageClass::Storage {
7436                                        access: crate::StorageAccess::LOAD,
7437                                        ..
7438                                    } => {
7439                                        // Array of textures since:
7440                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7441                                        // - macOS: Metal 2
7442
7443                                        if options.lang_version < (2, 0) {
7444                                            return Err(Error::UnsupportedArrayOf(
7445                                                "textures".to_string(),
7446                                            ));
7447                                        }
7448                                    }
7449                                    crate::ImageClass::Storage {
7450                                        access: crate::StorageAccess::STORE,
7451                                        ..
7452                                    } => {
7453                                        // Array of write-only textures since:
7454                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
7455                                        // - macOS: Metal 2
7456
7457                                        if options.lang_version < (2, 0) {
7458                                            return Err(Error::UnsupportedArrayOf(
7459                                                "write-only textures".to_string(),
7460                                            ));
7461                                        }
7462                                    }
7463                                    crate::ImageClass::Storage { .. } => {
7464                                        if options.lang_version < (3, 0) {
7465                                            return Err(Error::UnsupportedArrayOf(
7466                                                "read-write textures".to_string(),
7467                                            ));
7468                                        }
7469                                    }
7470                                    crate::ImageClass::External => {
7471                                        return Err(Error::UnsupportedArrayOf(
7472                                            "external textures".to_string(),
7473                                        ));
7474                                    }
7475                                },
7476                                _ => {
7477                                    return Err(Error::UnsupportedArrayOfType(base));
7478                                }
7479                            }
7480                        }
7481                        _ => {}
7482                    },
7483                    _ => {}
7484                }
7485
7486                // the resolves have already been checked for `!fake_missing_bindings` case
7487                let resolved = match var.space {
7488                    crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7489                    crate::AddressSpace::WorkGroup => None,
7490                    crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7491                    _ => options
7492                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7493                        .ok(),
7494                };
7495                if let Some(ref resolved) = resolved {
7496                    // Inline samplers are be defined in the EP body
7497                    if resolved.as_inline_sampler(options).is_some() {
7498                        continue;
7499                    }
7500                }
7501
7502                match module.types[var.ty].inner {
7503                    crate::TypeInner::Image {
7504                        class: crate::ImageClass::External,
7505                        ..
7506                    } => {
7507                        // External texture global variables get lowered to 3 textures
7508                        // and a constant buffer. We must emit a separate argument for
7509                        // each of these.
7510                        let target = match resolved {
7511                            Some(back::msl::ResolvedBinding::Resource(target)) => {
7512                                target.external_texture
7513                            }
7514                            _ => None,
7515                        };
7516
7517                        for i in 0..3 {
7518                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7519                                handle,
7520                                ExternalTextureNameKey::Plane(i),
7521                            )];
7522                            let ty_name = format!(
7523                                "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7524                            );
7525                            let name = plane_name.clone();
7526                            let binding = if let Some(ref target) = target {
7527                                format!(" [[texture({})]]", target.planes[i])
7528                            } else {
7529                                String::new()
7530                            };
7531                            args.push(EntryPointArgument {
7532                                ty_name,
7533                                name,
7534                                binding,
7535                                init: None,
7536                            });
7537                        }
7538                        let params_ty_name = &self.names
7539                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7540                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7541                            handle,
7542                            ExternalTextureNameKey::Params,
7543                        )];
7544                        let binding = if let Some(ref target) = target {
7545                            format!(" [[buffer({})]]", target.params)
7546                        } else {
7547                            String::new()
7548                        };
7549
7550                        args.push(EntryPointArgument {
7551                            ty_name: format!("constant {params_ty_name}&"),
7552                            name: params_name.clone(),
7553                            binding,
7554                            init: None,
7555                        });
7556                    }
7557                    _ => {
7558                        if var.space == crate::AddressSpace::WorkGroup
7559                            && ep.stage == crate::ShaderStage::Mesh
7560                        {
7561                            continue;
7562                        }
7563                        let tyvar = TypedGlobalVariable {
7564                            module,
7565                            names: &self.names,
7566                            handle,
7567                            usage,
7568                            reference: true,
7569                        };
7570                        let parts = tyvar.to_parts()?;
7571                        let mut binding = String::new();
7572                        if let Some(resolved) = resolved {
7573                            resolved.try_fmt(&mut binding)?;
7574                        }
7575                        args.push(EntryPointArgument {
7576                            ty_name: parts.ty_name,
7577                            name: parts.var_name,
7578                            binding,
7579                            init: var.init,
7580                        });
7581                    }
7582                }
7583            }
7584
7585            if do_vertex_pulling {
7586                if needs_vertex_id && v_existing_id.is_none() {
7587                    // Write the [[vertex_id]] argument.
7588                    args.push(EntryPointArgument {
7589                        ty_name: "uint".to_string(),
7590                        name: v_id.clone(),
7591                        binding: " [[vertex_id]]".to_string(),
7592                        init: None,
7593                    });
7594                }
7595
7596                if needs_instance_id && i_existing_id.is_none() {
7597                    args.push(EntryPointArgument {
7598                        ty_name: "uint".to_string(),
7599                        name: i_id.clone(),
7600                        binding: " [[instance_id]]".to_string(),
7601                        init: None,
7602                    });
7603                }
7604
7605                // Iterate vbm_resolved, output one argument for every vertex buffer,
7606                // using the names we generated earlier.
7607                for vbm in &vbm_resolved {
7608                    let id = &vbm.id;
7609                    let ty_name = &vbm.ty_name;
7610                    let param_name = &vbm.param_name;
7611                    args.push(EntryPointArgument {
7612                        ty_name: format!("const device {ty_name}*"),
7613                        name: param_name.clone(),
7614                        binding: format!(" [[buffer({id})]]"),
7615                        init: None,
7616                    });
7617                }
7618            }
7619
7620            // If this entry uses any variable-length arrays, their sizes are
7621            // passed as a final struct-typed argument.
7622            if needs_buffer_sizes {
7623                // this is checked earlier
7624                let resolved = options.resolve_sizes_buffer(ep).unwrap();
7625                let mut binding = String::new();
7626                resolved.try_fmt(&mut binding)?;
7627                args.push(EntryPointArgument {
7628                    ty_name: "constant _mslBufferSizes&".to_string(),
7629                    name: "_buffer_sizes".to_string(),
7630                    binding,
7631                    init: None,
7632                });
7633            }
7634
7635            let mut is_first_arg = true;
7636            for arg in &args {
7637                if is_first_arg {
7638                    write!(self.out, "  ")?;
7639                } else {
7640                    write!(self.out, ", ")?;
7641                }
7642                is_first_arg = false;
7643                write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7644                if !is_wrapped {
7645                    write!(self.out, "{}", arg.binding)?;
7646                    if let Some(init) = arg.init {
7647                        write!(self.out, " = ")?;
7648                        self.put_const_expression(
7649                            init,
7650                            module,
7651                            mod_info,
7652                            &module.global_expressions,
7653                        )?;
7654                    }
7655                }
7656                writeln!(self.out)?;
7657            }
7658            if ep.stage == crate::ShaderStage::Mesh {
7659                for (handle, var) in module.global_variables.iter() {
7660                    if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7661                        continue;
7662                    }
7663                    if is_first_arg {
7664                        write!(self.out, "  ")?;
7665                    } else {
7666                        write!(self.out, ", ")?;
7667                    }
7668                    let ty_context = TypeContext {
7669                        handle: module.global_variables[handle].ty,
7670                        gctx: module.to_ctx(),
7671                        names: &self.names,
7672                        access: crate::StorageAccess::empty(),
7673                        first_time: false,
7674                    };
7675                    writeln!(
7676                        self.out,
7677                        "threadgroup {ty_context}& {}",
7678                        self.names[&NameKey::GlobalVariable(handle)]
7679                    )?;
7680                }
7681            }
7682
7683            // end of the entry point argument list
7684            writeln!(self.out, ") {{")?;
7685
7686            // Starting the function body.
7687            if do_vertex_pulling {
7688                // Provide zero values for all the attributes, which we will overwrite with
7689                // real data from the vertex attribute buffers, if the indices are in-bounds.
7690                for vbm in &vbm_resolved {
7691                    for attribute in vbm.attributes {
7692                        let location = attribute.shader_location;
7693                        let am_option = am_resolved.get(&location);
7694                        if am_option.is_none() {
7695                            // This bound attribute isn't used in this entry point, so
7696                            // don't bother zero-initializing it.
7697                            continue;
7698                        }
7699                        let am = am_option.unwrap();
7700                        let attribute_ty_name = &am.ty_name;
7701                        let attribute_name = &am.name;
7702
7703                        writeln!(
7704                            self.out,
7705                            "{}{attribute_ty_name} {attribute_name} = {{}};",
7706                            back::Level(1)
7707                        )?;
7708                    }
7709
7710                    // Output a bounds check block that will set real values for the
7711                    // attributes, if the bounds are satisfied.
7712                    write!(self.out, "{}if (", back::Level(1))?;
7713
7714                    let idx = &vbm.id;
7715                    let stride = &vbm.stride;
7716                    let index_name = match vbm.step_mode {
7717                        back::msl::VertexBufferStepMode::Constant => "0",
7718                        back::msl::VertexBufferStepMode::ByVertex => {
7719                            if let Some(ref name) = v_existing_id {
7720                                name
7721                            } else {
7722                                &v_id
7723                            }
7724                        }
7725                        back::msl::VertexBufferStepMode::ByInstance => {
7726                            if let Some(ref name) = i_existing_id {
7727                                name
7728                            } else {
7729                                &i_id
7730                            }
7731                        }
7732                    };
7733                    write!(
7734                        self.out,
7735                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7736                    )?;
7737
7738                    writeln!(self.out, ") {{")?;
7739
7740                    // Pull the bytes out of the vertex buffer.
7741                    let ty_name = &vbm.ty_name;
7742                    let elem_name = &vbm.elem_name;
7743                    let param_name = &vbm.param_name;
7744
7745                    writeln!(
7746                        self.out,
7747                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7748                        back::Level(2),
7749                    )?;
7750
7751                    // Now set real values for each of the attributes, by unpacking the data
7752                    // from the buffer elements.
7753                    for attribute in vbm.attributes {
7754                        let location = attribute.shader_location;
7755                        let Some(am) = am_resolved.get(&location) else {
7756                            // This bound attribute isn't used in this entry point, so
7757                            // don't bother extracting the data. Too bad we emitted the
7758                            // unpacking function earlier -- it might not get used.
7759                            continue;
7760                        };
7761                        let attribute_name = &am.name;
7762                        let attribute_ty_name = &am.ty_name;
7763
7764                        let offset = attribute.offset;
7765                        let func = unpacking_functions
7766                            .get(&attribute.format)
7767                            .expect("Should have generated this unpacking function earlier.");
7768                        let func_name = &func.name;
7769
7770                        // Check dimensionality of the attribute compared to the unpacking
7771                        // function. If attribute dimension > unpack dimension, we have to
7772                        // pad out the unpack value from a vec4(0, 0, 0, 1) of matching
7773                        // scalar type. Otherwise, if attribute dimension is < unpack
7774                        // dimension, then we need to explicitly truncate the result.
7775                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7776
7777                        // We need an extra type conversion if the shader type does not
7778                        // match the type returned from the unpacking function.
7779                        let needs_conversion = am.scalar != func.scalar;
7780
7781                        if needs_padding_or_truncation != Ordering::Equal {
7782                            // Emit a comment flagging that a conversion is happening,
7783                            // since the actual logic can be at the end of a long line.
7784                            writeln!(
7785                                self.out,
7786                                "{}// {attribute_ty_name} <- {:?}",
7787                                back::Level(2),
7788                                attribute.format
7789                            )?;
7790                        }
7791
7792                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7793
7794                        if needs_padding_or_truncation == Ordering::Greater {
7795                            // Needs padding: emit constructor call for wider type
7796                            write!(self.out, "{attribute_ty_name}(")?;
7797                        }
7798
7799                        // Emit call to unpacking function
7800                        if needs_conversion {
7801                            put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7802                            write!(self.out, "(")?;
7803                        }
7804                        write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7805                        for i in (offset + 1)..(offset + func.byte_count) {
7806                            write!(self.out, ", {elem_name}.data[{i}]")?;
7807                        }
7808                        write!(self.out, ")")?;
7809                        if needs_conversion {
7810                            write!(self.out, ")")?;
7811                        }
7812
7813                        match needs_padding_or_truncation {
7814                            Ordering::Greater => {
7815                                // Padding
7816                                let ty_is_int = scalar_is_int(am.scalar);
7817                                let zero_value = if ty_is_int { "0" } else { "0.0" };
7818                                let one_value = if ty_is_int { "1" } else { "1.0" };
7819                                for i in func.dimension.map_or(1, u8::from)
7820                                    ..am.dimension.map_or(1, u8::from)
7821                                {
7822                                    write!(
7823                                        self.out,
7824                                        ", {}",
7825                                        if i == 3 { one_value } else { zero_value }
7826                                    )?;
7827                                }
7828                            }
7829                            Ordering::Less => {
7830                                // Truncate to the first `am.dimension` components
7831                                write!(
7832                                    self.out,
7833                                    ".{}",
7834                                    &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7835                                )?;
7836                            }
7837                            Ordering::Equal => {}
7838                        }
7839
7840                        if needs_padding_or_truncation == Ordering::Greater {
7841                            write!(self.out, ")")?;
7842                        }
7843
7844                        writeln!(self.out, ";")?;
7845                    }
7846
7847                    // End the bounds check / attribute setting block.
7848                    writeln!(self.out, "{}}}", back::Level(1))?;
7849                }
7850            }
7851
7852            // Metal doesn't support private mutable variables outside of functions,
7853            // so we put them here, just like the locals.
7854            for (handle, var) in module.global_variables.iter() {
7855                let usage = fun_info[handle];
7856                if usage.is_empty() {
7857                    continue;
7858                }
7859                if var.space == crate::AddressSpace::Private {
7860                    let tyvar = TypedGlobalVariable {
7861                        module,
7862                        names: &self.names,
7863                        handle,
7864                        usage,
7865
7866                        reference: false,
7867                    };
7868                    write!(self.out, "{}", back::INDENT)?;
7869                    tyvar.try_fmt(&mut self.out)?;
7870                    match var.init {
7871                        Some(value) => {
7872                            write!(self.out, " = ")?;
7873                            self.put_const_expression(
7874                                value,
7875                                module,
7876                                mod_info,
7877                                &module.global_expressions,
7878                            )?;
7879                            writeln!(self.out, ";")?;
7880                        }
7881                        None => {
7882                            writeln!(self.out, " = {{}};")?;
7883                        }
7884                    };
7885                } else if let Some(ref binding) = var.binding {
7886                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7887                    if let Some(sampler) = resolved.as_inline_sampler(options) {
7888                        // write an inline sampler
7889                        let name = &self.names[&NameKey::GlobalVariable(handle)];
7890                        writeln!(
7891                            self.out,
7892                            "{}constexpr {}::sampler {}(",
7893                            back::INDENT,
7894                            NAMESPACE,
7895                            name
7896                        )?;
7897                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
7898                        writeln!(self.out, "{});", back::INDENT)?;
7899                    } else if let crate::TypeInner::Image {
7900                        class: crate::ImageClass::External,
7901                        ..
7902                    } = module.types[var.ty].inner
7903                    {
7904                        // Wrap the individual arguments for each external texture global
7905                        // in a struct which can be easily passed around.
7906                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7907                        let l1 = back::Level(1);
7908                        let l2 = l1.next();
7909                        writeln!(
7910                            self.out,
7911                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7912                        )?;
7913                        for i in 0..3 {
7914                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7915                                handle,
7916                                ExternalTextureNameKey::Plane(i),
7917                            )];
7918                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7919                        }
7920                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7921                            handle,
7922                            ExternalTextureNameKey::Params,
7923                        )];
7924                        writeln!(self.out, "{l2}.params = {params_name},")?;
7925                        writeln!(self.out, "{l1}}};")?;
7926                    }
7927                }
7928            }
7929
7930            if need_workgroup_variables_initialization {
7931                self.write_workgroup_variables_initialization(
7932                    module,
7933                    mod_info,
7934                    fun_info,
7935                    local_invocation_index,
7936                    ep.stage,
7937                )?;
7938            }
7939
7940            // Now take the arguments that we gathered into structs, and the
7941            // structs that we flattened into arguments, and emit local
7942            // variables with initializers that put everything back the way the
7943            // body code expects.
7944            //
7945            // If we had to generate fresh names for struct members passed as
7946            // arguments, be sure to use those names when rebuilding the struct.
7947            //
7948            // "Each day, I change some zeros to ones, and some ones to zeros.
7949            // The rest, I leave alone."
7950            for (arg_index, arg) in fun.arguments.iter().enumerate() {
7951                let arg_name =
7952                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7953                match module.types[arg.ty].inner {
7954                    crate::TypeInner::Struct { ref members, .. } => {
7955                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
7956                        write!(
7957                            self.out,
7958                            "{}const {} {} = {{ ",
7959                            back::INDENT,
7960                            struct_name,
7961                            arg_name
7962                        )?;
7963                        for (member_index, member) in members.iter().enumerate() {
7964                            let key = NameKey::StructMember(arg.ty, member_index as u32);
7965                            let name = &flattened_member_names[&key];
7966                            if member_index != 0 {
7967                                write!(self.out, ", ")?;
7968                            }
7969                            // insert padding initialization, if needed
7970                            if self
7971                                .struct_member_pads
7972                                .contains(&(arg.ty, member_index as u32))
7973                            {
7974                                write!(self.out, "{{}}, ")?;
7975                            }
7976                            match member.binding {
7977                                Some(crate::Binding::Location {
7978                                    interpolation: Some(crate::Interpolation::PerVertex),
7979                                    ..
7980                                }) => {
7981                                    writeln!(
7982                                        self.out,
7983                                        "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
7984                                        back::INDENT,
7985                                        varyings_member_name,
7986                                        arg_name,
7987                                    )?;
7988                                    continue;
7989                                }
7990                                Some(crate::Binding::Location { .. }) => {
7991                                    if has_varyings {
7992                                        write!(self.out, "{varyings_member_name}.")?;
7993                                    }
7994                                }
7995                                _ => (),
7996                            }
7997                            write!(self.out, "{name}")?;
7998                        }
7999                        writeln!(self.out, " }};")?;
8000                    }
8001                    _ => match arg.binding {
8002                        Some(crate::Binding::Location {
8003                            interpolation: Some(crate::Interpolation::PerVertex),
8004                            ..
8005                        }) => {
8006                            let ty_name = TypeContext {
8007                                handle: arg.ty,
8008                                gctx: module.to_ctx(),
8009                                names: &self.names,
8010                                access: crate::StorageAccess::empty(),
8011                                first_time: false,
8012                            };
8013                            writeln!(
8014                                self.out,
8015                                "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8016                                back::INDENT,
8017                                varyings_member_name,
8018                                arg_name,
8019                            )?;
8020                        }
8021                        Some(crate::Binding::Location { .. })
8022                        | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8023                            if has_varyings {
8024                                writeln!(
8025                                    self.out,
8026                                    "{}const auto {} = {}.{};",
8027                                    back::INDENT,
8028                                    arg_name,
8029                                    varyings_member_name,
8030                                    arg_name
8031                                )?;
8032                            }
8033                        }
8034                        _ => {}
8035                    },
8036                }
8037            }
8038
8039            let guarded_indices =
8040                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8041
8042            let context = StatementContext {
8043                expression: ExpressionContext {
8044                    function: fun,
8045                    origin: FunctionOrigin::EntryPoint(ep_index as _),
8046                    info: fun_info,
8047                    lang_version: options.lang_version,
8048                    policies: options.bounds_check_policies,
8049                    guarded_indices,
8050                    module,
8051                    mod_info,
8052                    pipeline_options,
8053                    force_loop_bounding: options.force_loop_bounding,
8054                },
8055                result_struct: if ep.stage == crate::ShaderStage::Task {
8056                    None
8057                } else {
8058                    Some(&stage_out_name)
8059                },
8060            };
8061
8062            // Finally, declare all the local variables that we need
8063            //TODO: we can postpone this till the relevant expressions are emitted
8064            self.put_locals(&context.expression)?;
8065            self.update_expressions_to_bake(fun, fun_info, &context.expression);
8066            self.put_block(back::Level(1), &fun.body, &context)?;
8067            writeln!(self.out, "}}")?;
8068            if ep_index + 1 != module.entry_points.len() {
8069                writeln!(self.out)?;
8070            }
8071            self.named_expressions.clear();
8072
8073            if is_wrapped {
8074                self.write_wrapper_function(NestedFunctionInfo {
8075                    options,
8076                    ep,
8077                    module,
8078                    mod_info,
8079                    fun_info,
8080                    args,
8081                    local_invocation_index,
8082                    nested_name: &nested_fun_name,
8083                    outer_name: &fun_name,
8084                    out_mesh_info,
8085                })?;
8086            }
8087        }
8088
8089        Ok(info)
8090    }
8091
8092    pub(super) fn write_barrier(
8093        &mut self,
8094        flags: crate::Barrier,
8095        level: back::Level,
8096    ) -> BackendResult {
8097        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
8098        // so we try to avoid it here.
8099        if flags.is_empty() {
8100            writeln!(
8101                self.out,
8102                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8103            )?;
8104        }
8105        if flags.contains(crate::Barrier::STORAGE) {
8106            writeln!(
8107                self.out,
8108                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8109            )?;
8110        }
8111        if flags.contains(crate::Barrier::WORK_GROUP) {
8112            writeln!(
8113                self.out,
8114                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8115            )?;
8116            if self.needs_object_memory_barriers {
8117                writeln!(
8118                    self.out,
8119                    "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8120                )?;
8121            }
8122        }
8123        if flags.contains(crate::Barrier::SUB_GROUP) {
8124            writeln!(
8125                self.out,
8126                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8127            )?;
8128        }
8129        if flags.contains(crate::Barrier::TEXTURE) {
8130            writeln!(
8131                self.out,
8132                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8133            )?;
8134        }
8135        Ok(())
8136    }
8137}
8138
/// Initializing workgroup variables is trickier for Metal because atomics must be
/// handled at the type level (they don't have a copy constructor).
8141mod workgroup_mem_init {
8142    use crate::EntryPoint;
8143
8144    use super::*;
8145
8146    enum Access {
8147        GlobalVariable(Handle<crate::GlobalVariable>),
8148        StructMember(Handle<crate::Type>, u32),
8149        Array(usize),
8150    }
8151
8152    impl Access {
8153        fn write<W: Write>(
8154            &self,
8155            writer: &mut W,
8156            names: &FastHashMap<NameKey, String>,
8157        ) -> Result<(), core::fmt::Error> {
8158            match *self {
8159                Access::GlobalVariable(handle) => {
8160                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8161                }
8162                Access::StructMember(handle, index) => {
8163                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8164                }
8165                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8166            }
8167        }
8168    }
8169
8170    struct AccessStack {
8171        stack: Vec<Access>,
8172        array_depth: usize,
8173    }
8174
8175    impl AccessStack {
8176        const fn new() -> Self {
8177            Self {
8178                stack: Vec::new(),
8179                array_depth: 0,
8180            }
8181        }
8182
8183        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8184            let array_depth = self.array_depth;
8185            self.stack.push(Access::Array(array_depth));
8186            self.array_depth += 1;
8187            let res = cb(self, array_depth);
8188            self.stack.pop();
8189            self.array_depth -= 1;
8190            res
8191        }
8192
8193        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8194            self.stack.push(new);
8195            let res = cb(self);
8196            self.stack.pop();
8197            res
8198        }
8199
8200        fn write<W: Write>(
8201            &self,
8202            writer: &mut W,
8203            names: &FastHashMap<NameKey, String>,
8204        ) -> Result<(), core::fmt::Error> {
8205            for next in self.stack.iter() {
8206                next.write(writer, names)?;
8207            }
8208            Ok(())
8209        }
8210    }
8211
8212    impl<W: Write> Writer<W> {
8213        pub(super) fn need_workgroup_variables_initialization(
8214            &mut self,
8215            options: &Options,
8216            ep: &EntryPoint,
8217            module: &crate::Module,
8218            fun_info: &valid::FunctionInfo,
8219        ) -> bool {
8220            let is_task = ep.stage == crate::ShaderStage::Task;
8221            options.zero_initialize_workgroup_memory
8222                && ep.stage.compute_like()
8223                && module.global_variables.iter().any(|(handle, var)| {
8224                    let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8225                        || (var.space == crate::AddressSpace::TaskPayload && is_task);
8226                    !fun_info[handle].is_empty() && is_right_address_space
8227                })
8228        }
8229
8230        pub fn write_workgroup_variables_initialization(
8231            &mut self,
8232            module: &crate::Module,
8233            module_info: &valid::ModuleInfo,
8234            fun_info: &valid::FunctionInfo,
8235            local_invocation_index: Option<&NameKey>,
8236            stage: crate::ShaderStage,
8237        ) -> BackendResult {
8238            let level = back::Level(1);
8239
8240            writeln!(
8241                self.out,
8242                "{}if ({} == 0u) {{",
8243                level,
8244                local_invocation_index
8245                    .map(|name_key| self.names[name_key].as_str())
8246                    .unwrap_or("__local_invocation_index"),
8247            )?;
8248
8249            let mut access_stack = AccessStack::new();
8250
8251            let is_task = stage == crate::ShaderStage::Task;
8252            let vars = module.global_variables.iter().filter(|&(handle, var)| {
8253                let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8254                    || (var.space == crate::AddressSpace::TaskPayload && is_task);
8255                !fun_info[handle].is_empty() && is_right_address_space
8256            });
8257
8258            for (handle, var) in vars {
8259                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8260                    self.write_workgroup_variable_initialization(
8261                        module,
8262                        module_info,
8263                        var.ty,
8264                        access_stack,
8265                        level.next(),
8266                    )
8267                })?;
8268            }
8269
8270            writeln!(self.out, "{level}}}")?;
8271            self.write_barrier(crate::Barrier::WORK_GROUP, level)
8272        }
8273
8274        fn write_workgroup_variable_initialization(
8275            &mut self,
8276            module: &crate::Module,
8277            module_info: &valid::ModuleInfo,
8278            ty: Handle<crate::Type>,
8279            access_stack: &mut AccessStack,
8280            level: back::Level,
8281        ) -> BackendResult {
8282            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8283                write!(self.out, "{level}")?;
8284                access_stack.write(&mut self.out, &self.names)?;
8285                writeln!(self.out, " = {{}};")?;
8286            } else {
8287                match module.types[ty].inner {
8288                    crate::TypeInner::Atomic { .. } => {
8289                        write!(
8290                            self.out,
8291                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8292                        )?;
8293                        access_stack.write(&mut self.out, &self.names)?;
8294                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8295                    }
8296                    crate::TypeInner::Array { base, size, .. } => {
8297                        let count = match size.resolve(module.to_ctx())? {
8298                            proc::IndexableLength::Known(count) => count,
8299                            proc::IndexableLength::Dynamic => unreachable!(),
8300                        };
8301
8302                        access_stack.enter_array(|access_stack, array_depth| {
8303                            writeln!(
8304                                self.out,
8305                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8306                            )?;
8307                            self.write_workgroup_variable_initialization(
8308                                module,
8309                                module_info,
8310                                base,
8311                                access_stack,
8312                                level.next(),
8313                            )?;
8314                            writeln!(self.out, "{level}}}")?;
8315                            BackendResult::Ok(())
8316                        })?;
8317                    }
8318                    crate::TypeInner::Struct { ref members, .. } => {
8319                        for (index, member) in members.iter().enumerate() {
8320                            access_stack.enter(
8321                                Access::StructMember(ty, index as u32),
8322                                |access_stack| {
8323                                    self.write_workgroup_variable_initialization(
8324                                        module,
8325                                        module_info,
8326                                        member.ty,
8327                                        access_stack,
8328                                        level,
8329                                    )
8330                                },
8331                            )?;
8332                        }
8333                    }
8334                    _ => unreachable!(),
8335                }
8336            }
8337
8338            Ok(())
8339        }
8340    }
8341}
8342
8343impl crate::AtomicFunction {
8344    const fn to_msl(self) -> &'static str {
8345        match self {
8346            Self::Add => "fetch_add",
8347            Self::Subtract => "fetch_sub",
8348            Self::And => "fetch_and",
8349            Self::InclusiveOr => "fetch_or",
8350            Self::ExclusiveOr => "fetch_xor",
8351            Self::Min => "fetch_min",
8352            Self::Max => "fetch_max",
8353            Self::Exchange { compare: None } => "exchange",
8354            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8355        }
8356    }
8357
8358    fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8359        Ok(match self {
8360            Self::Min => "min",
8361            Self::Max => "max",
8362            _ => Err(Error::FeatureNotImplemented(
8363                "64-bit atomic operation other than min/max".to_string(),
8364            ))?,
8365        })
8366    }
8367}